Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions gimmik/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,84 @@ class HIPMatMul(MatMul):
basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0}

def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64):
# B loading, C streaming kernel
yield ('cstream', {}, {})
max_block_threads = 1024
max_shared = 64*1024

# B streaming, C accumulation kernel
yield ('bstream', {}, {})
def emit(name, args, meta):
block = meta.get('block', self.basemeta['block'])
shared = meta.get('shared', self.basemeta['shared'])
threads = block[0]*block[1]*block[2]

if threads <= max_block_threads and shared <= max_shared:
yield (name, args, meta)

def emit_preload(name, args, meta):
yield from emit(name, args | {'preload': True}, meta)

# Four-way m-split B streaming, C accumulation kernel
ms, bsz, blkx = 4, 24, 64
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx}
meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize}
yield ('bstream-msplit', args, meta)
meta = {
'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize,
'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'
}
yield from emit('bstream-msplit', args, meta)

# Two-way k-split B loading, C streaming kernel
ks, csz, blkx = 2, 24, 64
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize}
yield ('cstream-ksplit', args, meta)
meta = {
'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize,
'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'
}
yield from emit('cstream-ksplit', args, meta)

# Tuned HIP variants
msplits, ksplits = [8, 4], [4, 2]
bsz, csz, blkx = 8, 8, 64
widths = [1]
if self.aligne is not None and self.aligne % 2 == 0:
widths.insert(0, 2)

for width in widths:
wargs = ({'dtype': f'{dtype}{width}', 'width': width}
if width > 1 else {})
wmeta = {'width': width} if width > 1 else {}
wpfx = f'w{width}-' if width > 1 else ''

for ms in msplits:
# m-split B streaming, C accumulation kernel
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs
shared = 2*bsz*blkx*dsize*width
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': f'bstream-msplit/{wpfx}m{ms}-b{bsz}-x{blkx}'
} | wmeta
yield from emit('bstream-msplit', args, meta)

for ms in msplits:
# m-split B streaming, C preloading, C accumulation kernel
args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs
shared = 2*bsz*blkx*dsize*width
meta = {
'block': (blkx, ms, 1), 'shared': shared,
'desc': (
f'bstream-msplit-preload-c/'
f'{wpfx}m{ms}-b{bsz}-x{blkx}'
)
} | wmeta
yield from emit_preload('bstream-msplit', args, meta)

for ks in ksplits:
# k-split B loading, C preloading, C streaming kernel
args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs
shared = (ks - 1)*csz*blkx*dsize*width
meta = {
'block': (blkx, ks, 1), 'shared': shared,
'desc': (
f'cstream-ksplit-preload-c/'
f'{wpfx}k{ks}-c{csz}-x{blkx}'
)
} | wmeta
yield from emit_preload('cstream-ksplit', args, meta)

def _process_meta(self, meta):
if self.n is not None:
Expand Down
68 changes: 65 additions & 3 deletions gimmik/kernels/hip/base.mako
Original file line number Diff line number Diff line change
@@ -1,12 +1,74 @@
% if dtype.endswith('4'):
static inline __device__ ${dtype} make_zero()
inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b)
{ return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }

inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b)
{ return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); }

inline __device__ ${dtype} make_zero()
{ return make_${dtype}(0, 0, 0, 0); }
% elif dtype.endswith('2'):
static inline __device__ ${dtype} make_zero()
inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b)
{ return make_${dtype}(a.x + b.x, a.y + b.y); }

inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b)
{ return make_${dtype}(a*b.x, a*b.y); }

inline __device__ ${dtype} make_zero()
{ return make_${dtype}(0, 0); }
% else:
static inline __device__ ${dtype} make_zero()
inline __device__ ${dtype} make_zero()
{ return 0; }
% endif

static inline __device__ void
nt_store(${dtype}* p, ${dtype} v)
{
% if dtype.endswith('4'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
__builtin_nontemporal_store(v.z, &p->z);
__builtin_nontemporal_store(v.w, &p->w);
% elif dtype.endswith('2'):
__builtin_nontemporal_store(v.x, &p->x);
__builtin_nontemporal_store(v.y, &p->y);
% else:
__builtin_nontemporal_store(v, p);
% endif
}

static inline __device__ ${dtype}
nt_load(const ${dtype}* p)
{
% if dtype.endswith('4'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y),
__builtin_nontemporal_load(&p->z),
__builtin_nontemporal_load(&p->w));
% elif dtype.endswith('2'):
return make_${dtype}(__builtin_nontemporal_load(&p->x),
__builtin_nontemporal_load(&p->y));
% else:
return __builtin_nontemporal_load(p);
% endif
}

static inline __device__ void
store_c(${dtype}* p, ${dtype} v)
{
nt_store(p, v);
}

static inline __device__ ${dtype}
load_c(const ${dtype}* p)
{
return nt_load(p);
}

static inline __device__ ${dtype}
load_b(const ${dtype}* p)
{
return nt_load(p);
}

${next.body()}
38 changes: 28 additions & 10 deletions gimmik/kernels/hip/bstream-msplit.mako
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<%
mx = partition(A, into=msplit, by='rows')
bchunks = chunk(bix, bsz)
preload = context.get('preload', False)
%>

__global__ __launch_bounds__(${blockx*msplit}) void
Expand All @@ -12,7 +13,7 @@ ${kname}(int n,
${dtype}* __restrict__ c, int ldc)
{
% if width > 1:
n = ((n + ${width} - 1) / ${width}) * ${width};
n = (n + ${width} - 1) / ${width};
ldb /= ${width};
ldc /= ${width};
% endif
Expand All @@ -34,9 +35,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
% for kx in bchunks[0]:
% if loop.index % msplit == cid:
bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]);
% endif
% endfor

% if preload and beta != 0:
## Preload C values for active rows owned by this m-split lane
% for j, jx in enumerate(mx[cid]):
% if afix[jx] != -1:
% if beta == 1:
csub[${j}] = load_c(&c[i + ${jx}*ldc]);
% else:
csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]);
% endif
% endif
% endfor
% endif
}
% endfor
__syncthreads();
Expand All @@ -51,36 +65,40 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
% if not loop.parent.last:
% for kx in bchunks[bb + 1]:
% if loop.index % msplit == cid:
bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb];
bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]);
% endif
% endfor
% endif
## Accumulate our dot products
% for kx in bchunks[bb]:
bv = bsub[${bb % 2}][${loop.index}][threadIdx.x];
% for j, jx in enumerate(A[mcx, kx]):
% if jx != 0 and kx == afix[mcx[j]]:
% if preload and beta != 0 and jx != 0:
csub[${j}] += ${jx}*bv;
% elif jx != 0 and kx == afix[mcx[j]]:
csub[${j}] = ${jx}*bv;
% elif jx != 0:
csub[${j}] += ${jx}*bv;
% endif
## If we're done with this dot product then store to global
% if kx == alix[mcx[j]] and beta == 0:
c[i + ${mcx[j]}*ldc] = csub[${j}];
% if preload and kx == alix[mcx[j]]:
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% elif kx == alix[mcx[j]] and beta == 0:
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]);
% elif kx == alix[mcx[j]] and beta == 1:
c[i + ${mcx[j]}*ldc] += csub[${j}];
store_c(&c[i + ${mcx[j]}*ldc], load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]);
% elif kx == alix[mcx[j]]:
c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc];
store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${mcx[j]}*ldc]));
% endif
% endfor
% endfor
## Handle rows of A which are all zero
% if loop.parent.last:
% for j, jx in enumerate(afix):
% if jx == -1 and j % msplit == cid and beta == 0:
c[i + ${j}*ldc] = make_zero();
store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and j % msplit == cid and beta != 1:
c[i + ${j}*ldc] *= ${beta};
store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc]));
% endif
% endfor
% endif
Expand Down
41 changes: 30 additions & 11 deletions gimmik/kernels/hip/bstream.mako
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
<%inherit file='base'/>

__global__ __launch_bounds__(128) void
<% preload = context.get('preload', False) %>

__global__ __launch_bounds__(${blockx}) void
% if n is None:
${kname}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% if width > 1:
n = ((n + ${width} - 1) / ${width}) * ${width};
n = (n + ${width} - 1) / ${width};
ldb /= ${width};
ldc /= ${width};
% endif
Expand All @@ -24,32 +26,49 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
${dtype} bv, csub[${m}];

## Iterare through the used rows of B
% if preload and beta != 0:
## Preload C values for rows which will receive a non-zero dot product
% for j, jx in enumerate(afix):
% if jx != -1:
% if beta == 1:
csub[${j}] = load_c(&c[i + ${j}*ldc]);
% else:
csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]);
% endif
% endif
% endfor
% endif

## Iterate through the used rows of B
% for kx in bix:
bv = b[i + ${kx}*ldb];
bv = load_b(&b[i + ${kx}*ldb]);
% for j, jx in enumerate(A[:, kx]):
% if jx != 0 and kx == afix[j]:
% if preload and beta != 0 and jx != 0:
csub[${j}] += ${jx}*bv;
% elif jx != 0 and kx == afix[j]:
csub[${j}] = ${jx}*bv;
% elif jx != 0:
csub[${j}] += ${jx}*bv;
% endif
##
% if kx == alix[j] and beta == 0:
c[i + ${j}*ldc] = csub[${j}];
% if preload and kx == alix[j]:
store_c(&c[i + ${j}*ldc], csub[${j}]);
% elif kx == alix[j] and beta == 0:
store_c(&c[i + ${j}*ldc], csub[${j}]);
% elif kx == alix[j] and beta == 1:
c[i + ${j}*ldc] += csub[${j}];
store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + csub[${j}]);
% elif kx == alix[j]:
c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc];
store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${j}*ldc]));
% endif
% endfor
% endfor

## Handle rows of A which are all zero
% for j, jx in enumerate(afix):
% if jx == -1 and beta == 0:
c[i + ${j}*ldc] = make_zero();
store_c(&c[i + ${j}*ldc], make_zero());
% elif jx == -1 and beta != 1:
c[i + ${j}*ldc] *= ${beta};
store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc]));
% endif
% endfor
}
Expand Down
Loading