Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 64 additions & 48 deletions ds4_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5972,15 +5972,17 @@ __global__ static void router_select_kernel(
uint32_t hash_rows,
uint32_t n_tokens,
int has_bias,
int hash_mode) {
int hash_mode,
uint32_t n_expert,
float weight_scale) {
uint32_t t = blockIdx.x;
if (t >= n_tokens || threadIdx.x != 0) return;
const float *log = logits + (uint64_t)t * 256;
float *prob = probs + (uint64_t)t * 256;
const float *log = logits + (uint64_t)t * n_expert;
float *prob = probs + (uint64_t)t * n_expert;
int32_t *sel = selected + (uint64_t)t * 6;
float *w = weights + (uint64_t)t * 6;

for (int i = 0; i < 256; i++) prob[i] = sqrtf(softplus_dev(log[i]));
for (uint32_t i = 0; i < n_expert; i++) prob[i] = sqrtf(softplus_dev(log[i]));

if (hash_mode) {
int32_t tok = tokens ? tokens[t] : token_scalar;
Expand All @@ -5989,12 +5991,12 @@ __global__ static void router_select_kernel(
for (int i = 0; i < 6; i++) sel[i] = row[i];
} else {
for (int i = 0; i < 6; i++) sel[i] = -1;
for (int i = 0; i < 256; i++) {
for (uint32_t i = 0; i < n_expert; i++) {
float score = prob[i] + (has_bias ? bias[i] : 0.0f);
for (int j = 0; j < 6; j++) {
if (sel[j] < 0 || score > prob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) {
for (int k = 5; k > j; k--) sel[k] = sel[k - 1];
sel[j] = i;
sel[j] = (int32_t)i;
break;
}
}
Expand All @@ -6004,12 +6006,12 @@ __global__ static void router_select_kernel(
float sum = 0.0f;
for (int i = 0; i < 6; i++) {
int e = sel[i];
float v = (e >= 0 && e < 256) ? prob[e] : 0.0f;
float v = (e >= 0 && (uint32_t)e < n_expert) ? prob[e] : 0.0f;
w[i] = v;
sum += v;
}
sum = fmaxf(sum, 6.103515625e-5f);
for (int i = 0; i < 6; i++) w[i] = w[i] / sum * 1.5f;
for (int i = 0; i < 6; i++) w[i] = w[i] / sum * weight_scale;
}

__global__ static void router_select_parallel_kernel(
Expand All @@ -6024,15 +6026,17 @@ __global__ static void router_select_parallel_kernel(
uint32_t hash_rows,
uint32_t n_tokens,
int has_bias,
int hash_mode) {
int hash_mode,
uint32_t n_expert,
float weight_scale) {
uint32_t t = blockIdx.x;
uint32_t i = threadIdx.x;
if (t >= n_tokens || i >= 256u) return;
const float *log = logits + (uint64_t)t * 256;
float *prob = probs + (uint64_t)t * 256;
if (t >= n_tokens || i >= n_expert) return;
const float *log = logits + (uint64_t)t * n_expert;
float *prob = probs + (uint64_t)t * n_expert;
int32_t *sel = selected + (uint64_t)t * 6;
float *w = weights + (uint64_t)t * 6;
__shared__ float sprob[256];
__shared__ float sprob[512];

const float p = sqrtf(softplus_dev(log[i]));
sprob[i] = p;
Expand All @@ -6047,12 +6051,12 @@ __global__ static void router_select_parallel_kernel(
for (int j = 0; j < 6; j++) sel[j] = row[j];
} else {
for (int j = 0; j < 6; j++) sel[j] = -1;
for (int e = 0; e < 256; e++) {
for (uint32_t e = 0; e < n_expert; e++) {
float score = sprob[e] + (has_bias ? bias[e] : 0.0f);
for (int j = 0; j < 6; j++) {
if (sel[j] < 0 || score > sprob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) {
for (int k = 5; k > j; k--) sel[k] = sel[k - 1];
sel[j] = e;
sel[j] = (int32_t)e;
break;
}
}
Expand All @@ -6062,12 +6066,12 @@ __global__ static void router_select_parallel_kernel(
float sum = 0.0f;
for (int j = 0; j < 6; j++) {
int e = sel[j];
float v = (e >= 0 && e < 256) ? sprob[e] : 0.0f;
float v = (e >= 0 && (uint32_t)e < n_expert) ? sprob[e] : 0.0f;
w[j] = v;
sum += v;
}
sum = fmaxf(sum, 6.103515625e-5f);
for (int j = 0; j < 6; j++) w[j] = w[j] / sum * 1.5f;
for (int j = 0; j < 6; j++) w[j] = w[j] / sum * weight_scale;
}

__device__ __forceinline__ static bool router_score_better(float av, uint32_t ai, float bv, uint32_t bi) {
Expand All @@ -6086,22 +6090,26 @@ __global__ static void router_select_warp_topk_kernel(
uint32_t hash_rows,
uint32_t n_tokens,
int has_bias,
int hash_mode) {
int hash_mode,
uint32_t n_expert,
float weight_scale) {
const uint32_t lane = threadIdx.x;
const uint32_t row_in_block = threadIdx.y;
const uint32_t t = blockIdx.x * blockDim.y + row_in_block;
if (t >= n_tokens || lane >= 32u) return;

const float *log = logits + (uint64_t)t * 256u;
float *prob = probs + (uint64_t)t * 256u;
/* Each of the 32 warp lanes owns n_expert/32 experts. per_lane is 8 for
* 256 experts (Flash) and 12 for 384 (PRO); capped at 16 (512 experts). */
const uint32_t per_lane = n_expert >> 5;
const float *log = logits + (uint64_t)t * n_expert;
float *prob = probs + (uint64_t)t * n_expert;
int32_t *sel = selected + (uint64_t)t * 6u;
float *w = weights + (uint64_t)t * 6u;
__shared__ float sprob[4][256];
float local_prob[8];
float local_score[8];
__shared__ float sprob[4][512];
float local_prob[16];
float local_score[16];

#pragma unroll
for (uint32_t j = 0; j < 8u; j++) {
for (uint32_t j = 0; j < per_lane; j++) {
const uint32_t e = lane + j * 32u;
const float p = sqrtf(softplus_dev(log[e]));
local_prob[j] = p;
Expand All @@ -6121,13 +6129,13 @@ __global__ static void router_select_warp_topk_kernel(
for (uint32_t j = 0; j < 6u; j++) {
const int32_t e = row[j];
sel[j] = e;
const float v = (e >= 0 && e < 256) ? sprob[row_in_block][(uint32_t)e] : 0.0f;
const float v = (e >= 0 && (uint32_t)e < n_expert) ? sprob[row_in_block][(uint32_t)e] : 0.0f;
w[j] = v;
sum += v;
}
sum = fmaxf(sum, 6.103515625e-5f);
#pragma unroll
for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * 1.5f;
for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * weight_scale;
}
return;
}
Expand All @@ -6139,8 +6147,7 @@ __global__ static void router_select_warp_topk_kernel(
float best_score = -INFINITY;
float best_prob = 0.0f;
uint32_t best_idx = UINT32_MAX;
#pragma unroll
for (uint32_t j = 0; j < 8u; j++) {
for (uint32_t j = 0; j < per_lane; j++) {
const uint32_t e = lane + j * 32u;
const float s = local_score[j];
if (router_score_better(s, e, best_score, best_idx)) {
Expand All @@ -6160,8 +6167,7 @@ __global__ static void router_select_warp_topk_kernel(
best_idx = other_idx;
}
}
#pragma unroll
for (uint32_t j = 0; j < 8u; j++) {
for (uint32_t j = 0; j < per_lane; j++) {
const uint32_t e = lane + j * 32u;
if (e == best_idx) local_score[j] = -INFINITY;
}
Expand All @@ -6181,7 +6187,7 @@ __global__ static void router_select_warp_topk_kernel(
}
sum = fmaxf(sum, 6.103515625e-5f);
#pragma unroll
for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * 1.5f;
for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * weight_scale;
}
}

Expand Down Expand Up @@ -9530,14 +9536,14 @@ extern "C" int ds4_gpu_directional_steering_project_tensor(
}
extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t token, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits) {
if (!selected || !weights || !probs || !logits || !model_map || n_expert_groups > 1u || n_group_used > 0u) return 0;
if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0;
if (n_expert == 0u || n_expert > 512u || (n_expert & 31u) != 0u || n_expert_used != 6u) return 0;
int32_t tok = (int32_t)token;
int ok = 1;
const float *bias = NULL;
const int32_t *hash = NULL;
if (ok && has_bias && !hash_mode) {
if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) ok = 0;
else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias");
if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) ok = 0;
else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias");
if (!bias) ok = 0;
}
if (ok && hash_mode) {
Expand All @@ -9552,35 +9558,39 @@ extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_te
dim3 block(32, 4, 1);
router_select_warp_topk_kernel<<<1, block>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1,
has_bias && !hash_mode, hash_mode);
has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale);
} else if (getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) {
router_select_parallel_kernel<<<1, 256>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
router_select_parallel_kernel<<<1, n_expert>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1,
has_bias && !hash_mode, hash_mode);
has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale);
} else {
router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1,
has_bias && !hash_mode, hash_mode);
has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale);
}
ok = cuda_ok(cudaGetLastError(), "router_select launch");
}
return ok;
}
extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits, const ds4_gpu_tensor *tokens, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_tokens) {
if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0;
/* Routed-expert top-6 select. Generalized from the original Flash-only
* (256 experts, scale 1.5) path to any n_expert that is a multiple of 32
* up to 512 (warp owns n_expert/32 each), with the model's routed weight
* scale. Covers DeepSeek-V4 Flash (256/1.5) and PRO (384/2.5). */
if (n_expert == 0u || n_expert > 512u || (n_expert & 31u) != 0u || n_expert_used != 6u) return 0;
if (!selected || !weights || !probs || !logits || !tokens || !model_map || n_tokens == 0 ||
n_expert_groups > 1u || n_group_used > 0u ||
logits->bytes < (uint64_t)n_tokens * 256u * sizeof(float) ||
probs->bytes < (uint64_t)n_tokens * 256u * sizeof(float) ||
logits->bytes < (uint64_t)n_tokens * n_expert * sizeof(float) ||
probs->bytes < (uint64_t)n_tokens * n_expert * sizeof(float) ||
selected->bytes < (uint64_t)n_tokens * 6u * sizeof(int32_t) ||
weights->bytes < (uint64_t)n_tokens * 6u * sizeof(float)) {
return 0;
}
const float *bias = NULL;
const int32_t *hash = NULL;
if (has_bias && !hash_mode) {
if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) return 0;
bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias");
if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) return 0;
bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias");
if (!bias) return 0;
}
if (hash_mode) {
Expand All @@ -9603,9 +9613,11 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_
hash_rows,
n_tokens,
has_bias && !hash_mode,
hash_mode);
hash_mode,
n_expert,
expert_weight_scale);
} else if (getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) {
router_select_parallel_kernel<<<n_tokens, 256>>>((int32_t *)selected->ptr,
router_select_parallel_kernel<<<n_tokens, n_expert>>>((int32_t *)selected->ptr,
(float *)weights->ptr,
(float *)probs->ptr,
bias,
Expand All @@ -9616,7 +9628,9 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_
hash_rows,
n_tokens,
has_bias && !hash_mode,
hash_mode);
hash_mode,
n_expert,
expert_weight_scale);
} else {
router_select_kernel<<<n_tokens, 1>>>((int32_t *)selected->ptr,
(float *)weights->ptr,
Expand All @@ -9629,7 +9643,9 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_
hash_rows,
n_tokens,
has_bias && !hash_mode,
hash_mode);
hash_mode,
n_expert,
expert_weight_scale);
}
return cuda_ok(cudaGetLastError(), "router_select launch");
}
Expand Down