Skip to content

vulkan: add v_dot2_f32_f16 support in matrix-matrix multiplication and Flash Attention#24123

Open
0cc4m wants to merge 4 commits into
masterfrom
0cc4m/vulkan-valve-dot2
Open

vulkan: add v_dot2_f32_f16 support in matrix-matrix multiplication and Flash Attention#24123
0cc4m wants to merge 4 commits into
masterfrom
0cc4m/vulkan-valve-dot2

Conversation

@0cc4m
Copy link
Copy Markdown
Contributor

@0cc4m 0cc4m commented Jun 4, 2026

Overview

This PR adds basic support for the Vulkan extension VK_VALVE_shader_mixed_float_dot_product. The background to this is that AMD Vega20, Navi14 and RDNA2+ GPUs have fp16 dot2 instructions for machine learning acceleration that are not emitted by the shader compiler due to numerical inconsistencies. The extension allows shaders to manually emit them.

This PR adds support for the v_dot2_f32_f16 fp16 packed dot product with fp32 accumulator in matrix-matrix multiplications and Flash Attention. This is a good improvement for AMD GPUs with this instruction, but without coopmat support.

AMD Radeon Pro VII (Vega20) Benchmarks
Test Before After Δ%
MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 6.220 8.170 +31.35%
MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.490 9.840 +119.15%
MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 5.380 5.320 -1.12%
MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 15.370 15.230 -0.91%
MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 14.950 14.800 -1.00%
MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 9.980 9.900 -0.80%
MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 9.890 9.820 -0.71%
MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 13.190 13.150 -0.30%
MUL_MAT(type_a=q1_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.880 10.730 +119.88%
MUL_MAT(type_a=mxfp4,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 14.690 14.610 -0.54%
MUL_MAT(type_a=nvfp4,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.740 10.230 +115.82%
MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 7.240 7.220 -0.28%
MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 10.360 10.310 -0.48%
MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 11.590 11.550 -0.35%
MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 8.190 8.180 -0.12%
MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 7.950 7.940 -0.13%
MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.690 9.900 +111.09%
MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.670 9.810 +110.06%
MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.610 9.640 +109.11%
MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.500 9.050 +101.11%
MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.700 9.990 +112.55%
MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.640 9.760 +110.34%
MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.750 10.220 +115.16%
MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.500 8.990 +99.78%
MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) 4.560 9.510 +108.55%
MUL_MAT_ID(type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 2.640 3.000 +13.64%
MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 2.370 4.700 +98.31%
MUL_MAT_ID(type_a=q4_0,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 6.320 6.280 -0.63%
MUL_MAT_ID(type_a=q8_0,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 5.660 5.700 +0.71%
MUL_MAT_ID(type_a=q4_K,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 6.030 6.040 +0.17%
MUL_MAT_ID(type_a=q6_K,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 4.400 4.420 +0.45%
MUL_MAT_ID(type_a=iq2_xs,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) 2.400 5.060 +110.83%
MUL_MAT_ID(type_a=f32,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 3.400 3.470 +2.06%
MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 3.270 5.760 +76.15%
MUL_MAT_ID(type_a=q4_0,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 8.900 8.430 -5.28%
MUL_MAT_ID(type_a=q8_0,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 7.950 7.640 -3.90%
MUL_MAT_ID(type_a=q4_K,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 7.940 8.030 +1.13%
MUL_MAT_ID(type_a=q6_K,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 6.110 6.190 +1.31%
MUL_MAT_ID(type_a=iq2_xs,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) 3.080 6.640 +115.58%
MUL_MAT_ID(type_a=mxfp4,type_b=f32,n_mats=32,n_used=4,b=0,m=2880,n=512,k=2880) 8.210 8.420 +2.56%
Test Before After Δ%
FLASH_ATTN_EXT(hsk=72,hsv=72,nh=16,nr23=[1,1],kv=5776,nb=5776,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 2.740 2.550 -6.93%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 1.410 1.490 +5.67%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=4,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 1.890 2.080 +10.05%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q4_0,type_V=q4_0,permute=[0,1,2,3]) 1.350 1.350 +0.00%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=512,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q4_0,type_V=q4_0,permute=[0,1,2,3]) 2.930 2.880 -1.71%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q8_0,type_V=q8_0,permute=[0,1,2,3]) 1.380 1.390 +0.72%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=512,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q8_0,type_V=q8_0,permute=[0,1,2,3]) 3.150 3.160 +0.32%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.307 0.311 +1.19%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.822 0.890 +8.29%
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.446 0.456 +2.06%
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 1.220 1.280 +4.92%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.420 0.415 -1.24%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.997 1.110 +11.33%
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.542 0.542 -0.14%
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 1.380 1.450 +5.07%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.499 0.481 -3.75%
FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 1.280 1.450 +13.28%
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 0.591 0.590 -0.02%
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) 1.550 1.640 +5.81%
model size params ngl fa mmap test t/s (before) t/s (after) diff
llama 1B F16 2.05 GiB 1.10 B -1 1 0 pp512 2188.04 ± 7.71 4112.38 ± 13.61 +87.9%
llama 1B F16 2.05 GiB 1.10 B -1 1 0 tg128 227.67 ± 1.46 224.17 ± 0.67 -1.5%
llama 1B F16 2.05 GiB 1.10 B -1 1 0 pp512 @ d4096 1596.55 ± 10.42 2596.67 ± 5.88 +62.6%
llama 1B F16 2.05 GiB 1.10 B -1 1 0 tg128 @ d4096 191.41 ± 0.83 193.39 ± 0.86 +1.0%
llama 1B F16 2.05 GiB 1.10 B -1 1 0 pp512 @ d8192 1228.05 ± 8.12 1865.52 ± 16.46 +51.9%
llama 1B F16 2.05 GiB 1.10 B -1 1 0 tg128 @ d8192 175.61 ± 0.61 179.68 ± 0.34 +2.3%
llama 8B Q4_0 4.33 GiB 8.03 B -1 1 0 pp512 846.17 ± 2.87 828.88 ± 1.90 -2.0%
llama 8B Q4_0 4.33 GiB 8.03 B -1 1 0 tg128 101.36 ± 0.16 102.01 ± 0.29 +0.6%
llama 8B Q4_0 4.33 GiB 8.03 B -1 1 0 pp512 @ d4096 346.27 ± 0.94 449.96 ± 3.69 +29.9%
llama 8B Q4_0 4.33 GiB 8.03 B -1 1 0 tg128 @ d4096 87.77 ± 0.11 88.57 ± 0.11 +0.9%
llama 8B Q4_0 4.33 GiB 8.03 B -1 1 0 pp512 @ d8192 190.72 ± 1.93 275.47 ± 1.97 +44.4%
llama 8B Q4_0 4.33 GiB 8.03 B -1 1 0 tg128 @ d8192 79.34 ± 0.02 79.22 ± 0.94 -0.2%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B -1 1 0 pp512 755.22 ± 5.48 790.23 ± 9.42 +4.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B -1 1 0 tg128 71.77 ± 0.05 72.30 ± 0.01 +0.7%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B -1 1 0 pp512 @ d4096 379.66 ± 0.80 438.63 ± 1.63 +15.5%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B -1 1 0 tg128 @ d4096 52.90 ± 0.01 53.14 ± 0.01 +0.5%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B -1 1 0 pp512 @ d8192 251.56 ± 0.33 286.02 ± 0.75 +13.7%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B -1 1 0 tg128 @ d8192 42.19 ± 0.10 42.30 ± 0.02 +0.3%
llama 8B IQ4_XS - 4.25 bpw 4.13 GiB 8.03 B -1 1 0 pp512 325.59 ± 0.34 633.87 ± 1.11 +94.7%
llama 8B IQ4_XS - 4.25 bpw 4.13 GiB 8.03 B -1 1 0 tg128 61.86 ± 0.06 61.70 ± 0.05 -0.3%
llama 8B IQ4_XS - 4.25 bpw 4.13 GiB 8.03 B -1 1 0 pp512 @ d4096 207.74 ± 0.77 388.60 ± 2.32 +87.1%
llama 8B IQ4_XS - 4.25 bpw 4.13 GiB 8.03 B -1 1 0 tg128 @ d4096 56.53 ± 0.05 56.69 ± 0.02 +0.3%
llama 8B IQ4_XS - 4.25 bpw 4.13 GiB 8.03 B -1 1 0 pp512 @ d8192 138.80 ± 1.06 251.68 ± 4.89 +81.3%
llama 8B IQ4_XS - 4.25 bpw 4.13 GiB 8.03 B -1 1 0 tg128 @ d8192 52.49 ± 0.02 52.89 ± 0.06 +0.8%

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, Claude wrote the code, I reviewed and tested it.

@0cc4m 0cc4m requested a review from a team as a code owner June 4, 2026 12:42
@github-actions github-actions Bot added Vulkan Issues specific to the Vulkan backend ggml changes relating to the ggml tensor library for machine learning labels Jun 4, 2026
Comment thread ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp Outdated
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Vulkan Issues specific to the Vulkan backend

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants