Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions iris_metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ static int bf16_linear_use_graph(int seq_len, int in_dim, int out_dim) {
static id<MTLComputePipelineState> g_transpose_to_heads_bf16_pipeline;
static id<MTLComputePipelineState> g_transpose_from_heads_bf16_pipeline;
static id<MTLComputePipelineState> g_attention_fused_bf16_pipeline;
/* Weight conversion pipeline */
static id<MTLComputePipelineState> g_bf16_to_f16_pipeline;
/* F32 VAE pipelines */
static id<MTLComputePipelineState> g_group_norm_f32_pipeline;
static id<MTLComputePipelineState> g_swish_f32_pipeline;
Expand Down Expand Up @@ -1037,34 +1039,58 @@ static void clear_f16_cache(void) {
}
}

/* Convert bf16 to f16 */
uint16_t *f16_data = malloc(num_elements * sizeof(uint16_t));
if (!f16_data) {
pthread_mutex_unlock(&g_f16_cache_mutex);
return nil;
}
for (size_t i = 0; i < num_elements; i++) {
f16_data[i] = bf16_to_f16(weights[i]);
size_t size = num_elements * sizeof(uint16_t);
id<MTLBuffer> buf = nil;

/* Use GPU kernel for bf16→f16 conversion when available.
* For large weight tensors (e.g. 84M elements for a single block's
* QKV+MLP weights) this is significantly faster than a CPU loop. */
if (g_shaders_initialized && g_bf16_to_f16_pipeline) {
id<MTLBuffer> input_buf = [g_device newBufferWithBytes:weights
length:size
options:MTLResourceStorageModeShared];
id<MTLBuffer> output_buf = [g_device newBufferWithLength:size
options:MTLResourceStorageModeShared];
if (input_buf && output_buf) {
id<MTLCommandBuffer> cmdBuffer = [g_queue commandBuffer];
id<MTLComputeCommandEncoder> encoder = [cmdBuffer computeCommandEncoder];
int n = (int)num_elements;
[encoder setComputePipelineState:g_bf16_to_f16_pipeline];
[encoder setBuffer:input_buf offset:0 atIndex:0];
[encoder setBuffer:output_buf offset:0 atIndex:1];
[encoder setBytes:&n length:sizeof(int) atIndex:2];
NSUInteger threads = 256;
NSUInteger groups = (num_elements + threads - 1) / threads;
[encoder dispatchThreadgroups:MTLSizeMake(groups, 1, 1)
threadsPerThreadgroup:MTLSizeMake(threads, 1, 1)];
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
buf = output_buf;
}
}

size_t size = num_elements * sizeof(uint16_t);
/* Fallback: CPU conversion */
if (!buf) {
uint16_t *f16_data = malloc(size);
if (!f16_data) {
pthread_mutex_unlock(&g_f16_cache_mutex);
return nil;
}
for (size_t i = 0; i < num_elements; i++)
f16_data[i] = bf16_to_f16(weights[i]);
buf = [g_device newBufferWithBytes:f16_data
length:size
options:MTLResourceStorageModeShared];
free(f16_data);
}

/* Cache is full - just create buffer without caching */
/* Cache is full - return without caching */
if (g_f16_cache_count >= F16_WEIGHT_CACHE_SIZE) {
id<MTLBuffer> buf = [g_device newBufferWithBytes:f16_data
length:size
options:MTLResourceStorageModeShared];
free(f16_data);
pthread_mutex_unlock(&g_f16_cache_mutex);
return buf;
}

/* Create and cache */
id<MTLBuffer> buf = [g_device newBufferWithBytes:f16_data
length:size
options:MTLResourceStorageModeShared];
free(f16_data);

g_f16_cache[g_f16_cache_count].cpu_ptr = weights;
g_f16_cache[g_f16_cache_count].gpu_buffer = buf;
g_f16_cache[g_f16_cache_count].size = size;
Expand Down Expand Up @@ -3471,6 +3497,10 @@ int iris_metal_init_shaders(void) {
if (func) {
g_upsample_nearest_2x_f32_pipeline = [g_device newComputePipelineStateWithFunction:func error:&error];
}
func = [g_shader_library newFunctionWithName:@"bf16_to_f16_convert"];
if (func) {
g_bf16_to_f16_pipeline = [g_device newComputePipelineStateWithFunction:func error:&error];
}

g_shaders_initialized = 1;
if (iris_verbose)
Expand Down
29 changes: 29 additions & 0 deletions iris_shaders.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2134,3 +2134,32 @@ kernel void upsample_nearest_2x_f32(

out[c * out_spatial + oy * out_w + ox] = x[c * in_h * in_w + iy * in_w + ix];
}

/* Convert bf16 to f16 for MPS compatibility.
* bf16: sign(1) + exp(8) + mant(7)
* f16: sign(1) + exp(5) + mant(10)
* Converts in parallel on GPU — much faster than a CPU loop for large
* weight tensors (e.g. 3072*27648 = 84M elements for a single block). */
inline ushort bf16_to_f16_val(ushort bf16) {
uint sign = (bf16 >> 15) & 0x1;
int exp = (bf16 >> 7) & 0xFF; /* bf16 exponent (bias 127) */
uint mant = bf16 & 0x7F; /* bf16 mantissa (7 bits) */

if (exp == 0) return ushort(sign << 15); /* zero/denormal → zero */
if (exp == 0xFF) return ushort((sign << 15) | 0x7C00 | (mant != 0 ? 0x200 : 0)); /* inf/NaN */

int new_exp = exp - 127 + 15; /* rebias: bf16 bias=127, f16 bias=15 */
if (new_exp <= 0) return ushort(sign << 15); /* underflow → zero */
if (new_exp >= 31) return ushort((sign << 15) | 0x7C00); /* overflow → inf */

return ushort((sign << 15) | (uint(new_exp) << 10) | (mant << 3)); /* expand 7→10 mantissa bits */
}

kernel void bf16_to_f16_convert(
device const ushort *input [[buffer(0)]],
device ushort *output [[buffer(1)]],
constant int &n [[buffer(2)]],
uint gid [[thread_position_in_grid]]
) {
if (gid < uint(n)) output[gid] = bf16_to_f16_val(input[gid]);
}