diff --git a/ds4_cuda.cu b/ds4_cuda.cu index 188b341ad..0dbea7f60 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -5972,15 +5972,17 @@ __global__ static void router_select_kernel( uint32_t hash_rows, uint32_t n_tokens, int has_bias, - int hash_mode) { + int hash_mode, + uint32_t n_expert, + float weight_scale) { uint32_t t = blockIdx.x; if (t >= n_tokens || threadIdx.x != 0) return; - const float *log = logits + (uint64_t)t * 256; - float *prob = probs + (uint64_t)t * 256; + const float *log = logits + (uint64_t)t * n_expert; + float *prob = probs + (uint64_t)t * n_expert; int32_t *sel = selected + (uint64_t)t * 6; float *w = weights + (uint64_t)t * 6; - for (int i = 0; i < 256; i++) prob[i] = sqrtf(softplus_dev(log[i])); + for (uint32_t i = 0; i < n_expert; i++) prob[i] = sqrtf(softplus_dev(log[i])); if (hash_mode) { int32_t tok = tokens ? tokens[t] : token_scalar; @@ -5989,12 +5991,12 @@ __global__ static void router_select_kernel( for (int i = 0; i < 6; i++) sel[i] = row[i]; } else { for (int i = 0; i < 6; i++) sel[i] = -1; - for (int i = 0; i < 256; i++) { + for (uint32_t i = 0; i < n_expert; i++) { float score = prob[i] + (has_bias ? bias[i] : 0.0f); for (int j = 0; j < 6; j++) { if (sel[j] < 0 || score > prob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) { for (int k = 5; k > j; k--) sel[k] = sel[k - 1]; - sel[j] = i; + sel[j] = (int32_t)i; break; } } @@ -6004,12 +6006,12 @@ __global__ static void router_select_kernel( float sum = 0.0f; for (int i = 0; i < 6; i++) { int e = sel[i]; - float v = (e >= 0 && e < 256) ? prob[e] : 0.0f; + float v = (e >= 0 && (uint32_t)e < n_expert) ? prob[e] : 0.0f; w[i] = v; sum += v; } sum = fmaxf(sum, 6.103515625e-5f); - for (int i = 0; i < 6; i++) w[i] = w[i] / sum * 1.5f; + for (int i = 0; i < 6; i++) w[i] = w[i] / sum * weight_scale; } __global__ static void router_select_parallel_kernel( @@ -6024,15 +6026,17 @@ __global__ static void router_select_parallel_kernel( uint32_t hash_rows, uint32_t n_tokens, int has_bias, - int hash_mode) { + int hash_mode, + uint32_t n_expert, + float weight_scale) { uint32_t t = blockIdx.x; uint32_t i = threadIdx.x; - if (t >= n_tokens || i >= 256u) return; - const float *log = logits + (uint64_t)t * 256; - float *prob = probs + (uint64_t)t * 256; + if (t >= n_tokens || i >= n_expert) return; + const float *log = logits + (uint64_t)t * n_expert; + float *prob = probs + (uint64_t)t * n_expert; int32_t *sel = selected + (uint64_t)t * 6; float *w = weights + (uint64_t)t * 6; - __shared__ float sprob[256]; + __shared__ float sprob[512]; const float p = sqrtf(softplus_dev(log[i])); sprob[i] = p; @@ -6047,12 +6051,12 @@ __global__ static void router_select_parallel_kernel( for (int j = 0; j < 6; j++) sel[j] = row[j]; } else { for (int j = 0; j < 6; j++) sel[j] = -1; - for (int e = 0; e < 256; e++) { + for (uint32_t e = 0; e < n_expert; e++) { float score = sprob[e] + (has_bias ? bias[e] : 0.0f); for (int j = 0; j < 6; j++) { if (sel[j] < 0 || score > sprob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) { for (int k = 5; k > j; k--) sel[k] = sel[k - 1]; - sel[j] = e; + sel[j] = (int32_t)e; break; } } @@ -6062,12 +6066,12 @@ __global__ static void router_select_parallel_kernel( float sum = 0.0f; for (int j = 0; j < 6; j++) { int e = sel[j]; - float v = (e >= 0 && e < 256) ? sprob[e] : 0.0f; + float v = (e >= 0 && (uint32_t)e < n_expert) ? sprob[e] : 0.0f; w[j] = v; sum += v; } sum = fmaxf(sum, 6.103515625e-5f); - for (int j = 0; j < 6; j++) w[j] = w[j] / sum * 1.5f; + for (int j = 0; j < 6; j++) w[j] = w[j] / sum * weight_scale; } __device__ __forceinline__ static bool router_score_better(float av, uint32_t ai, float bv, uint32_t bi) { @@ -6086,22 +6090,26 @@ __global__ static void router_select_warp_topk_kernel( uint32_t hash_rows, uint32_t n_tokens, int has_bias, - int hash_mode) { + int hash_mode, + uint32_t n_expert, + float weight_scale) { const uint32_t lane = threadIdx.x; const uint32_t row_in_block = threadIdx.y; const uint32_t t = blockIdx.x * blockDim.y + row_in_block; if (t >= n_tokens || lane >= 32u) return; - const float *log = logits + (uint64_t)t * 256u; - float *prob = probs + (uint64_t)t * 256u; + /* Each of the 32 warp lanes owns n_expert/32 experts. per_lane is 8 for + * 256 experts (Flash) and 12 for 384 (PRO); capped at 16 (512 experts). */ + const uint32_t per_lane = n_expert >> 5; + const float *log = logits + (uint64_t)t * n_expert; + float *prob = probs + (uint64_t)t * n_expert; int32_t *sel = selected + (uint64_t)t * 6u; float *w = weights + (uint64_t)t * 6u; - __shared__ float sprob[4][256]; - float local_prob[8]; - float local_score[8]; + __shared__ float sprob[4][512]; + float local_prob[16]; + float local_score[16]; - #pragma unroll - for (uint32_t j = 0; j < 8u; j++) { + for (uint32_t j = 0; j < per_lane; j++) { const uint32_t e = lane + j * 32u; const float p = sqrtf(softplus_dev(log[e])); local_prob[j] = p; @@ -6121,13 +6129,13 @@ __global__ static void router_select_warp_topk_kernel( for (uint32_t j = 0; j < 6u; j++) { const int32_t e = row[j]; sel[j] = e; - const float v = (e >= 0 && e < 256) ? sprob[row_in_block][(uint32_t)e] : 0.0f; + const float v = (e >= 0 && (uint32_t)e < n_expert) ? sprob[row_in_block][(uint32_t)e] : 0.0f; w[j] = v; sum += v; } sum = fmaxf(sum, 6.103515625e-5f); #pragma unroll - for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * 1.5f; + for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * weight_scale; } return; } @@ -6139,8 +6147,7 @@ __global__ static void router_select_warp_topk_kernel( float best_score = -INFINITY; float best_prob = 0.0f; uint32_t best_idx = UINT32_MAX; - #pragma unroll - for (uint32_t j = 0; j < 8u; j++) { + for (uint32_t j = 0; j < per_lane; j++) { const uint32_t e = lane + j * 32u; const float s = local_score[j]; if (router_score_better(s, e, best_score, best_idx)) { @@ -6160,8 +6167,7 @@ __global__ static void router_select_warp_topk_kernel( best_idx = other_idx; } } - #pragma unroll - for (uint32_t j = 0; j < 8u; j++) { + for (uint32_t j = 0; j < per_lane; j++) { const uint32_t e = lane + j * 32u; if (e == best_idx) local_score[j] = -INFINITY; } @@ -6181,7 +6187,7 @@ __global__ static void router_select_warp_topk_kernel( } sum = fmaxf(sum, 6.103515625e-5f); #pragma unroll - for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * 1.5f; + for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * weight_scale; } } @@ -9530,14 +9536,14 @@ extern "C" int ds4_gpu_directional_steering_project_tensor( } extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t token, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits) { if (!selected || !weights || !probs || !logits || !model_map || n_expert_groups > 1u || n_group_used > 0u) return 0; - if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0; + if (n_expert == 0u || n_expert > 512u || (n_expert & 31u) != 0u || n_expert_used != 6u) return 0; int32_t tok = (int32_t)token; int ok = 1; const float *bias = NULL; const int32_t *hash = NULL; if (ok && has_bias && !hash_mode) { - if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) ok = 0; - else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias"); + if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) ok = 0; + else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias"); if (!bias) ok = 0; } if (ok && hash_mode) { @@ -9552,26 +9558,30 @@ extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_te dim3 block(32, 4, 1); router_select_warp_topk_kernel<<<1, block>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, - has_bias && !hash_mode, hash_mode); + has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale); } else if (getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) { - router_select_parallel_kernel<<<1, 256>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, + router_select_parallel_kernel<<<1, n_expert>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, - has_bias && !hash_mode, hash_mode); + has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale); } else { router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, - has_bias && !hash_mode, hash_mode); + has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale); } ok = cuda_ok(cudaGetLastError(), "router_select launch"); } return ok; } extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits, const ds4_gpu_tensor *tokens, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_tokens) { - if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0; + /* Routed-expert top-6 select. Generalized from the original Flash-only + * (256 experts, scale 1.5) path to any n_expert that is a multiple of 32 + * up to 512 (warp owns n_expert/32 each), with the model's routed weight + * scale. Covers DeepSeek-V4 Flash (256/1.5) and PRO (384/2.5). */ + if (n_expert == 0u || n_expert > 512u || (n_expert & 31u) != 0u || n_expert_used != 6u) return 0; if (!selected || !weights || !probs || !logits || !tokens || !model_map || n_tokens == 0 || n_expert_groups > 1u || n_group_used > 0u || - logits->bytes < (uint64_t)n_tokens * 256u * sizeof(float) || - probs->bytes < (uint64_t)n_tokens * 256u * sizeof(float) || + logits->bytes < (uint64_t)n_tokens * n_expert * sizeof(float) || + probs->bytes < (uint64_t)n_tokens * n_expert * sizeof(float) || selected->bytes < (uint64_t)n_tokens * 6u * sizeof(int32_t) || weights->bytes < (uint64_t)n_tokens * 6u * sizeof(float)) { return 0; @@ -9579,8 +9589,8 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ const float *bias = NULL; const int32_t *hash = NULL; if (has_bias && !hash_mode) { - if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) return 0; - bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias"); + if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) return 0; + bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias"); if (!bias) return 0; } if (hash_mode) { @@ -9603,9 +9613,11 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ hash_rows, n_tokens, has_bias && !hash_mode, - hash_mode); + hash_mode, + n_expert, + expert_weight_scale); } else if (getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) { - router_select_parallel_kernel<<>>((int32_t *)selected->ptr, + router_select_parallel_kernel<<>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, bias, @@ -9616,7 +9628,9 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ hash_rows, n_tokens, has_bias && !hash_mode, - hash_mode); + hash_mode, + n_expert, + expert_weight_scale); } else { router_select_kernel<<>>((int32_t *)selected->ptr, (float *)weights->ptr, @@ -9629,7 +9643,9 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ hash_rows, n_tokens, has_bias && !hash_mode, - hash_mode); + hash_mode, + n_expert, + expert_weight_scale); } return cuda_ok(cudaGetLastError(), "router_select launch"); }