32 commits
d9bf478
wip: add CUDA/cuBLAS backend for NVIDIA GPU acceleration
ServeurpersoCom Jan 22, 2026
5627ac2
wip: fix CUDA initialization and linking
ServeurpersoCom Jan 22, 2026
ec8ab62
wip: first working version, fix cuBLAS row-major to column-major conv…
ServeurpersoCom Jan 22, 2026
2c8c7bc
fix: use caller-provided lda/ldb/ldc in cuBLAS sgemm
ServeurpersoCom Jan 22, 2026
aa39882
perf: weight cache - 26% faster (5.9s vs 8.0s denoising)
ServeurpersoCom Jan 22, 2026
d60a79f
wip: Scratch Buffers - Reusable GPU memory for activations
ServeurpersoCom Jan 22, 2026
8b61cd7
feat(cuda): GPU tensor pool for single_block
ServeurpersoCom Jan 23, 2026
d04ab1e
feat(cuda): GPU batched attention - 53% total speedup from baseline
ServeurpersoCom Jan 23, 2026
839995b
feat(cuda): GPU joint attention for double blocks - 60% total speedup
ServeurpersoCom Jan 23, 2026
02301df
perf(cuda): tensor pool for attention - reduce malloc overhead
ServeurpersoCom Jan 23, 2026
fa8dc19
feat(cuda): full GPU double_block_forward - 64% total speedup
ServeurpersoCom Jan 23, 2026
bdc3642
feat(cuda): GPU conv2d for VAE decoder - 82% faster VAE
ServeurpersoCom Jan 23, 2026
67fad05
feat(cuda): GPU RoPE + VAE conv2d - 66% total speedup (2.96x)
ServeurpersoCom Jan 23, 2026
d42e685
feat(cuda): Qwen3 causal attention on GPU - 42% faster text encoding
ServeurpersoCom Jan 23, 2026
34b2143
fix(cuda): 64-bit indexing in im2col for 1024x1024 VAE decode
ServeurpersoCom Jan 23, 2026
c12ccaa
doc: update
ServeurpersoCom Jan 23, 2026
b45e756
test: add CUDA test runner with GPU-appropriate tolerance
ServeurpersoCom Jan 23, 2026
25a0fd8
fix(cuda): disable weight cache in mmap mode to fix stale weight bug
Jan 23, 2026
43fde85
fix(cli): silence GCC truncation warnings in flux_cli.c
Jan 24, 2026
064a489
fix(cuda): auto-detect GPU architecture, fix TF32 for Turing
ServeurpersoCom Jan 24, 2026
9f31eaa
test script
ServeurpersoCom Jan 24, 2026
a3f8fee
refactor(cuda): remove dead code, optimize small buffer allocations
ServeurpersoCom Jan 24, 2026
8220611
feat(cuda): Add BF16 weight caching for single blocks
ServeurpersoCom Jan 24, 2026
2fb7202
perf(cuda): Add chained single block path to reduce CPU/GPU transfers
ServeurpersoCom Jan 24, 2026
ed57964
fix: Don't free bf16 mmap pointers in transformer cleanup
ServeurpersoCom Jan 24, 2026
03e7c5f
perf(cuda): Pre-compute AdaLN modulation once for all single blocks
ServeurpersoCom Jan 24, 2026
6f11845
cuda: use cublasGemmEx with CUBLAS_COMPUTE_32F_FAST_16F
ServeurpersoCom Jan 24, 2026
7ab2f3c
fix(cuda): disable BF16 weight cache in no-mmap mode
ServeurpersoCom Jan 24, 2026
233dbc1
nit
ServeurpersoCom Jan 24, 2026
f90270c
remove dead code
ServeurpersoCom Jan 24, 2026
95bb5de
doc
ServeurpersoCom Jan 25, 2026
a3b2492
feat(windows): add Windows build support with secure temp file handling
ServeurpersoCom Feb 4, 2026
230 changes: 230 additions & 0 deletions CUDA_IMPLEMENTATION.md
@@ -0,0 +1,230 @@
# CUDA Implementation Notes

This document describes the CUDA acceleration layer for flux2.c.

---

## Files Modified for CUDA

| File | Changes |
|------|---------|
| `flux_cuda.cu` | Main CUDA implementation (kernels, cuBLAS, tensor pool) |
| `flux_cuda.h` | Public API declarations |
| `flux_transformer.c` | CUDA paths for double/single blocks, BF16 weight loading |
| `flux_vae.c` | CUDA conv2d for VAE decoder |
| `flux_qwen3.c` | CUDA causal attention for text encoder |
| `Makefile` | `make cuda` target with nvcc compilation |

---

## Architecture Overview

### GPU Acceleration Strategy

1. **Weights stay on GPU** - In mmap mode, BF16 weights are uploaded once and cached (see BF16 Weight Handling)
2. **Activations in tensor pool** - Reusable GPU buffers avoid malloc/free
3. **Minimal CPU↔GPU transfers** - Only upload inputs, download outputs
4. **cuBLAS for matmuls** - Uses tensor cores via `cublasGemmEx`
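
The commit history for this change mentions a row-major to column-major fix for cuBLAS. The idea can be shown on the CPU. The sketch below is not the code in `flux_cuda.cu`, and `gemm_colmajor` and `gemm_rowmajor` are hypothetical names. A row-major matrix read as column-major is its transpose, so a column-major GEMM can produce a row-major product if the operands are swapped:

```c
#include <assert.h>

/* Column-major GEMM without transposes: C(m x n) = A(m x k) * B(k x n).
 * Element (i, j) of a column-major matrix X lives at X[i + j * ldx],
 * which is the storage convention cuBLAS uses. */
static void gemm_colmajor(int m, int n, int k,
                          const float *A, int lda,
                          const float *B, int ldb,
                          float *C, int ldc)
{
    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            float acc = 0.0f;
            for (int p = 0; p < k; p++)
                acc += A[i + p * lda] * B[p + j * ldb];
            C[i + j * ldc] = acc;
        }
    }
}

/* Row-major C(M x N) = A(M x K) * B(K x N), all tightly packed.
 * We ask the column-major routine for C^T = B^T * A^T:
 * swap A and B, swap M and N, and pass the row lengths as leading dims. */
void gemm_rowmajor(int M, int N, int K,
                   const float *A, const float *B, float *C)
{
    assert(M > 0 && N > 0 && K > 0);
    gemm_colmajor(N, M, K, B, N, A, K, C, N);
}
```

With this arrangement no data is physically transposed; only the argument order and the leading dimensions change.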

### Key Data Structures

```c
// Tensor pool - reusable GPU buffers
g_tensor_pool[64] // Pool of GPU allocations
flux_cuda_tensor_get(size) // Acquire buffer
flux_cuda_tensor_release(id) // Release buffer

// Weight cache - permanent GPU storage for weights
g_weight_cache[2048] // CPU ptr → GPU ptr mapping
weight_cache_get() // Lookup cached weight
weight_cache_add() // Upload and cache weight
```
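
As a rough illustration of the tensor pool idea, here is a minimal CPU sketch. It assumes a fixed table of 64 slots and uses `malloc`/`free` as stand-ins for GPU allocation; the names `pool_get`, `pool_ptr`, and `pool_release` are hypothetical and are not the signatures of `flux_cuda_tensor_get` or `flux_cuda_tensor_release`.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

#define POOL_SLOTS 64                 /* mirrors the 64-entry pool in the notes */

typedef struct {
    void  *ptr;                       /* backing allocation (a GPU buffer in the real code) */
    size_t capacity;                  /* bytes allocated for this slot, 0 if empty */
    int    in_use;                    /* 1 while a caller holds the buffer */
} pool_slot;

static pool_slot pool[POOL_SLOTS];

/* Return the id of a buffer with at least `size` bytes, or -1 on failure. */
int pool_get(size_t size)
{
    assert(size > 0);
    int spare = -1;                   /* first idle slot that cannot be reused as-is */
    for (int i = 0; i < POOL_SLOTS; i++) {
        if (pool[i].in_use) continue;
        if (pool[i].capacity >= size) {   /* idle and big enough: reuse, no allocation */
            pool[i].in_use = 1;
            return i;
        }
        if (spare < 0) spare = i;
    }
    if (spare < 0) return -1;         /* every slot is busy */
    free(pool[spare].ptr);            /* drop a too-small buffer; free(NULL) is fine */
    pool[spare].ptr = malloc(size);
    pool[spare].capacity = pool[spare].ptr ? size : 0;
    if (!pool[spare].ptr) return -1;
    pool[spare].in_use = 1;
    return spare;
}

/* Address of a buffer that is currently held. */
void *pool_ptr(int id)
{
    assert(id >= 0 && id < POOL_SLOTS && pool[id].in_use);
    return pool[id].ptr;
}

/* Hand the buffer back; the allocation is kept for the next pool_get. */
void pool_release(int id)
{
    assert(id >= 0 && id < POOL_SLOTS && pool[id].in_use);
    pool[id].in_use = 0;
}
```

The point of the design is that a release does not free memory, so a later request of equal or smaller size is served without a new allocation.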

---

## CUDA Kernels

### Transformer Operations

| Kernel | Purpose |
|--------|---------|
| `k_silu` | SiLU activation |
| `k_silu_mul` | Fused SiLU + elementwise multiply (SwiGLU) |
| `k_mul` | Elementwise multiply |
| `k_gated_add` | Gated residual: `out += gate * x` |
| `k_split_fused` | Split fused QKV+MLP projection |
| `k_concat` | Concatenate attention + MLP outputs |
| `k_rms_norm` | RMSNorm |
| `k_qk_rms_norm` | Fused Q/K normalization |
| `k_adaln_norm` | AdaLN modulation |
| `k_softmax` | Row-wise softmax |
| `k_softmax_attention` | Fused attention softmax with scale |

### RoPE Kernels

| Kernel | Purpose |
|--------|---------|
| `k_rope_2d` | 2D RoPE for transformer (4 axes) |
| `k_rope_2d_offset` | RoPE with sequence offset |

### VAE Kernels

| Kernel | Purpose |
|--------|---------|
| `k_im2col` | im2col for conv2d |
| `k_add_bias_conv` | Add bias after convolution |
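
One of the commits in this change fixes 64-bit indexing in im2col for 1024x1024 VAE decode. The following CPU sketch shows what im2col computes and where wide indices matter. It assumes a `[C, H, W]` input, a square `K x K` kernel, and an output laid out as `[C*K*K, out_h*out_w]`; the name `im2col_ref` and this layout are assumptions, not taken from `flux_cuda.cu`.

```c
#include <assert.h>
#include <stdint.h>

/* Unfold a [C, H, W] image into a [C*K*K, out_h*out_w] matrix so that a
 * convolution becomes one matrix multiply. Out-of-bounds taps read as 0. */
void im2col_ref(const float *in, float *col,
                int C, int H, int W, int K, int stride, int pad)
{
    assert(C > 0 && H > 0 && W > 0 && K > 0 && stride > 0 && pad >= 0);
    int out_h = (H + 2 * pad - K) / stride + 1;
    int out_w = (W + 2 * pad - K) / stride + 1;
    /* 64-bit on purpose: rows * plane can exceed the range of a 32-bit int
     * for large images with many channels. */
    int64_t plane = (int64_t)out_h * out_w;

    for (int c = 0; c < C; c++) {
        for (int ky = 0; ky < K; ky++) {
            for (int kx = 0; kx < K; kx++) {
                int64_t row = ((int64_t)c * K + ky) * K + kx;
                for (int oy = 0; oy < out_h; oy++) {
                    for (int ox = 0; ox < out_w; ox++) {
                        int iy = oy * stride - pad + ky;
                        int ix = ox * stride - pad + kx;
                        float v = 0.0f;
                        if (iy >= 0 && iy < H && ix >= 0 && ix < W)
                            v = in[((int64_t)c * H + iy) * W + ix];
                        col[row * plane + (int64_t)oy * out_w + ox] = v;
                    }
                }
            }
        }
    }
}
```

The casts to `int64_t` happen before the multiplications, which is what prevents the intermediate products from wrapping.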

### Text Encoder Kernels

| Kernel | Purpose |
|--------|---------|
| `k_causal_softmax` | Causal attention with mask |
| `k_bf16_to_f32` | BF16→F32 conversion on GPU |

### Utility Kernels

| Kernel | Purpose |
|--------|---------|
| `k_transpose_shd_to_hsd` | Transpose [seq,heads,dim] → [heads,seq,dim] |
| `k_transpose_hsd_to_shd` | Transpose [heads,seq,dim] → [seq,heads,dim] |
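
The two transposes are inverses of each other. A plain C reference, with hypothetical function names and assuming densely packed row-major tensors, looks like this:

```c
#include <assert.h>
#include <stddef.h>

/* [seq, heads, dim] -> [heads, seq, dim] */
void transpose_shd_to_hsd_ref(float *dst, const float *src,
                              int seq, int heads, int dim)
{
    assert(dst != src && seq > 0 && heads > 0 && dim > 0);
    for (int s = 0; s < seq; s++)
        for (int h = 0; h < heads; h++)
            for (int d = 0; d < dim; d++)
                dst[((size_t)h * seq + s) * dim + d] =
                    src[((size_t)s * heads + h) * dim + d];
}

/* [heads, seq, dim] -> [seq, heads, dim] */
void transpose_hsd_to_shd_ref(float *dst, const float *src,
                              int seq, int heads, int dim)
{
    assert(dst != src && seq > 0 && heads > 0 && dim > 0);
    for (int h = 0; h < heads; h++)
        for (int s = 0; s < seq; s++)
            for (int d = 0; d < dim; d++)
                dst[((size_t)s * heads + h) * dim + d] =
                    src[((size_t)h * seq + s) * dim + d];
}
```

The `[heads, seq, dim]` layout makes each head a contiguous `seq x dim` matrix, which is the shape a batched GEMM over heads expects.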

---

## BF16 Weight Handling

### mmap Mode
- Weights read directly from mmap'd safetensors file as BF16
- Pointers are stable (point into mmap region)
- Weight cache **enabled** - weights uploaded once, cached permanently
- `g_weight_cache_disabled = 0`

### no-mmap Mode
- Weights copied via `safetensors_get_bf16()` into malloc'd buffers
- Buffer addresses may be reused after `free`, so a cache keyed by CPU pointer could return stale weights
- Weight cache **disabled** - weights uploaded fresh each time
- `g_weight_cache_disabled = 1`
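
To make the difference between the two modes concrete, here is a small CPU sketch of a cache keyed by host pointer. Everything in it is hypothetical (the names, the linear search, and the `int` device handle standing in for a GPU pointer); it only illustrates why an address-keyed cache is safe for stable mmap pointers and unsafe when the memory behind an address can change.

```c
#include <assert.h>
#include <stddef.h>

#define CACHE_SLOTS 2048              /* mirrors the 2048-entry cache in the notes */

typedef struct {
    const void *host_key;             /* CPU address used as the lookup key */
    int         device_id;            /* stand-in for the cached GPU pointer */
} cache_entry;

static cache_entry cache[CACHE_SLOTS];
static int cache_count = 0;
static int cache_disabled = 0;        /* plays the role of g_weight_cache_disabled */

/* Return the cached handle for host_ptr, or -1 on a miss. */
int cache_lookup(const void *host_ptr)
{
    if (cache_disabled) return -1;
    for (int i = 0; i < cache_count; i++)
        if (cache[i].host_key == host_ptr)
            return cache[i].device_id;
    return -1;
}

/* Remember that host_ptr was uploaded as device_id. Returns 1 if stored. */
int cache_insert(const void *host_ptr, int device_id)
{
    assert(host_ptr != NULL);
    if (cache_disabled || cache_count == CACHE_SLOTS) return 0;
    cache[cache_count].host_key = host_ptr;
    cache[cache_count].device_id = device_id;
    cache_count++;
    return 1;
}
```

The lookup compares addresses only, never contents. If a buffer is freed and a different weight later lands at the same address, the cache would still report a hit, which is the stale-weight hazard that disabling the cache avoids.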

### BF16→F32 Conversion
```c
flux_cuda_sgemm_gpu_bf16() // For mmap with cache
// 1. Check cache for existing F32 conversion
// 2. If miss: upload BF16, convert to F32 on GPU, cache result
// 3. Run cuBLAS sgemm with F32 weights
```
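
The conversion itself is simple because BF16 is the upper 16 bits of an IEEE-754 single-precision value. A CPU sketch of what a kernel such as `k_bf16_to_f32` computes per element (the function names below are hypothetical):

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Widen one BF16 value: place its 16 bits in the high half of a 32-bit word. */
static float bf16_to_f32(uint16_t h)
{
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);      /* bit cast without aliasing problems */
    return f;
}

/* Convert an array of n BF16 values to F32. */
void bf16_to_f32_array(float *dst, const uint16_t *src, size_t n)
{
    assert(dst && src);
    for (size_t i = 0; i < n; i++)
        dst[i] = bf16_to_f32(src[i]);
}
```

The widening is exact: every BF16 value is representable as an F32, so no rounding occurs in this direction.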

---

## Transformer Forward Paths

### Double Blocks (`double_block_forward_cuda`)
1. Upload img/txt hidden states to GPU
2. AdaLN modulation (fused for all streams)
3. QKV projection via `flux_cuda_sgemm_gpu_bf16`
4. Q/K normalization + RoPE
5. Joint attention via `flux_cuda_joint_attention_t`
6. Output projection + gated residual
7. MLP (SwiGLU) + gated residual
8. Download results to CPU

### Single Blocks (`single_block_forward_cuda_chained`)
- **Chained execution** - hidden state stays on GPU across all 20 blocks
- AdaLN vectors pre-computed once for all blocks
- Only final result downloaded to CPU

---

## Attention Implementation

### Joint Attention (Double Blocks)
```
img_out = softmax(img_Q @ cat(img_K, txt_K)^T) @ cat(img_V, txt_V)
txt_out = softmax(txt_Q @ cat(img_K, txt_K)^T) @ cat(img_V, txt_V)
```
- Uses `flux_cuda_joint_attention_t`
- Batched cuBLAS gemm for Q@K^T and scores@V

### Causal Attention (Qwen3 Text Encoder)
- GQA with 32 query heads, 8 KV heads (4:1 ratio)
- Causal mask + attention mask
- Uses `flux_cuda_causal_attention`
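
The head grouping and masking rules can be sketched in a few lines of C. Two assumptions are made here and are not confirmed by this document: consecutive query heads share a KV head (heads 0 to 3 map to KV head 0, and so on), and a nonzero entry in the attention mask means "this key is a real token". The function names are hypothetical.

```c
#include <assert.h>
#include <math.h>                     /* INFINITY */

/* Map a query head to the KV head it shares (ASSUMED contiguous grouping). */
int kv_head_for(int q_head, int n_q_heads, int n_kv_heads)
{
    assert(n_kv_heads > 0 && n_q_heads % n_kv_heads == 0);
    assert(q_head >= 0 && q_head < n_q_heads);
    return q_head / (n_q_heads / n_kv_heads);
}

/* Mask one row of attention scores in place, before softmax.
 * Query position i may attend to key position j only if
 *   j <= i            (causal mask), and
 *   attn_mask[j] != 0 (ASSUMED: nonzero marks a non-padding token). */
void mask_score_row(float *scores, int seq, int i, const int *attn_mask)
{
    assert(scores && attn_mask && i >= 0 && i < seq);
    for (int j = 0; j < seq; j++)
        if (j > i || !attn_mask[j])
            scores[j] = -INFINITY;
}
```

Setting a score to negative infinity makes its softmax weight exactly zero, so masked positions contribute nothing to the output.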

---

## Performance Characteristics

### Typical 1024×1024 @ 4 steps (RTX PRO 6000 Blackwell)

| Phase | Time | Notes |
|-------|------|-------|
| Text encoding | ~3s | Qwen3 36 layers, CUDA attention |
| Denoising | ~7s | 5 double + 20 single blocks |
| VAE decode | ~3.5s | CUDA conv2d |
| **Total** | ~14s | |

### Memory Usage
- Transformer weights: ~8GB (BF16)
- Qwen3 weights: ~8GB (F32, loaded per-layer in mmap mode)
- Activations: ~2GB peak
- Weight cache: Grows to ~4GB for transformer

---

## Build Instructions

```bash
# Build with CUDA support
make cuda

# Requirements:
# - CUDA toolkit (nvcc)
# - cuBLAS
# - OpenBLAS (for CPU fallback)

# GPU architecture auto-detected, or override:
make cuda CUDA_ARCH=sm_89 # Ada
make cuda CUDA_ARCH=sm_120 # Blackwell
```

---

## Debugging

### Enable verbose output
```c
// In flux_cuda.cu, uncomment:
// #define CUDA_DEBUG
```

### Check GPU memory
```bash
nvidia-smi --query-gpu=memory.used,memory.free --format=csv -l 1
```

### Verify correctness
```bash
# Generate with CPU reference
./flux_cpu -d model -p "test" -o ref.png --seed 42

# Generate with CUDA
./flux -d model -p "test" -o cuda.png --seed 42

# Compare ref.png and cuda.png (they should be nearly identical; small FP differences are OK)
```
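
No comparison command is given above. One possible approach, sketched here under the assumption that both images have already been decoded into 8-bit channel buffers of equal length (PNG decoding is not shown, and `count_mismatches` is a hypothetical helper), is to count values that differ by more than a small tolerance:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Count channel values whose absolute difference exceeds `tol` (0..255 scale). */
size_t count_mismatches(const uint8_t *a, const uint8_t *b, size_t n, int tol)
{
    assert(a && b && tol >= 0);
    size_t bad = 0;
    for (size_t i = 0; i < n; i++) {
        int diff = (int)a[i] - (int)b[i];
        if (diff < 0) diff = -diff;
        if (diff > tol) bad++;
    }
    return bad;
}
```

A tolerance of 1 or 2 levels is a reasonable starting point for float rounding differences, but the right threshold depends on the model and step count and is left to the reader.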

---

## Known Limitations

1. **No Flash Attention** - Attention uses batched cuBLAS GEMMs plus a softmax kernel, so the full score matrix is materialized
2. **No FP16 compute** - All compute in FP32 (weights can be BF16)
3. **Single GPU only** - No multi-GPU support
4. **No dynamic batching** - Single image at a time

---

## Future Optimizations

- [ ] Flash Attention 2 integration
- [ ] FP16 compute path for Ampere+
- [ ] Persistent kernel for single blocks
- [ ] CUDA graphs for reduced launch overhead
67 changes: 64 additions & 3 deletions Makefile
@@ -21,7 +21,7 @@ LIB = libflux.a
# Debug build flags
DEBUG_CFLAGS = -Wall -Wextra -g -O0 -DDEBUG -fsanitize=address

.PHONY: all clean debug lib install info test pngtest help generic blas mps
.PHONY: all clean debug lib install info test pngtest help generic blas mps cuda

# Default: show available targets
all: help
@@ -36,6 +36,8 @@ ifeq ($(UNAME_S),Darwin)
ifeq ($(UNAME_M),arm64)
@echo " make mps - Apple Silicon with Metal GPU (fastest)"
endif
else
@echo " make cuda - NVIDIA GPU with CUDA/cuBLAS (fastest)"
endif
@echo ""
@echo "Other targets:"
@@ -45,7 +47,11 @@ endif
@echo " make info - Show build configuration"
@echo " make lib - Build static library"
@echo ""
ifeq ($(UNAME_S),Darwin)
@echo "Example: make mps && ./flux -d flux-klein-model -p \"a cat\" -o cat.png"
else
@echo "Example: make cuda && ./flux -d flux-klein-model -p \"a cat\" -o cat.png"
endif

# =============================================================================
# Backend: generic (pure C, no BLAS)
@@ -107,6 +113,57 @@ mps:
@exit 1
endif

# =============================================================================
# Backend: cuda (NVIDIA GPU with CUDA/cuBLAS)
# =============================================================================
# CUDA Toolkit paths - adjust if needed
CUDA_PATH ?= /usr/local/cuda
NVCC = $(CUDA_PATH)/bin/nvcc

# Detect CUDA availability
CUDA_AVAILABLE := $(shell which $(NVCC) 2>/dev/null)

ifdef CUDA_AVAILABLE
CUDA_CFLAGS = $(CFLAGS_BASE) -DUSE_CUDA -DUSE_BLAS -I$(CUDA_PATH)/include
CUDA_NVCCFLAGS = -O3 -use_fast_math --compiler-options "$(CFLAGS_BASE)"
CUDA_LDFLAGS = $(LDFLAGS) -L$(CUDA_PATH)/lib64 -lcudart -lcublas -lopenblas -lstdc++

# Auto-detect GPU architecture from installed GPU, fallback to multi-arch fat binary
DETECTED_COMPUTE := $(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -1 | tr -d '.')
ifneq ($(DETECTED_COMPUTE),)
CUDA_ARCH ?= sm_$(DETECTED_COMPUTE)
CUDA_NVCCFLAGS += -arch=$(CUDA_ARCH)
else
# Fat binary: Turing (RTX 2080), Ampere (RTX 3090), Ada (RTX 4090), Hopper (H100), Blackwell (RTX 5090)
CUDA_NVCCFLAGS += -gencode arch=compute_75,code=sm_75 \
-gencode arch=compute_86,code=sm_86 \
-gencode arch=compute_89,code=sm_89 \
-gencode arch=compute_90,code=sm_90 \
-gencode arch=compute_120,code=sm_120
endif

cuda: clean cuda-build
@echo ""
@echo "Built with CUDA backend (NVIDIA GPU acceleration)"
@echo "Using GPU architecture: $(CUDA_ARCH)"

cuda-build: $(SRCS:.c=.cuda.o) $(CLI_SRCS:.c=.cuda.o) flux_cuda.o main.cuda.o
$(CC) $(CUDA_CFLAGS) -o $(TARGET) $^ $(CUDA_LDFLAGS)

%.cuda.o: %.c flux.h flux_kernels.h
$(CC) $(CUDA_CFLAGS) -c -o $@ $<

flux_cuda.o: flux_cuda.cu flux_cuda.h
$(NVCC) $(CUDA_NVCCFLAGS) -c -o $@ $<

else
cuda:
@echo "Error: CUDA toolkit not found"
@echo "Please install CUDA toolkit and ensure nvcc is in PATH"
@echo "Or set CUDA_PATH environment variable"
@exit 1
endif

# =============================================================================
# Build rules
# =============================================================================
@@ -153,7 +210,7 @@ install: $(TARGET) $(LIB)
install -m 644 flux_kernels.h /usr/local/include/

clean:
rm -f $(OBJS) $(CLI_OBJS) *.mps.o flux_metal.o main.o $(TARGET) $(LIB)
rm -f $(OBJS) $(CLI_OBJS) *.mps.o *.cuda.o flux_metal.o flux_cuda.o main.o $(TARGET) $(LIB)
rm -f flux_shaders_source.h

info:
@@ -169,13 +226,16 @@ ifeq ($(UNAME_M),arm64)
endif
else
@echo " blas - OpenBLAS (requires libopenblas-dev)"
ifdef CUDA_AVAILABLE
@echo " cuda - NVIDIA GPU (requires CUDA toolkit)"
endif
endif

# =============================================================================
# Dependencies
# =============================================================================
flux.o: flux.c flux.h flux_kernels.h flux_safetensors.h flux_qwen3.h
flux_kernels.o: flux_kernels.c flux_kernels.h
flux_kernels.o: flux_kernels.c flux_kernels.h flux_cuda.h
flux_tokenizer.o: flux_tokenizer.c flux.h
flux_vae.o: flux_vae.c flux.h flux_kernels.h
flux_transformer.o: flux_transformer.c flux.h flux_kernels.h
@@ -188,4 +248,5 @@ terminals.o: terminals.c terminals.h flux.h
flux_cli.o: flux_cli.c flux_cli.h flux.h flux_qwen3.h embcache.h linenoise.h terminals.h
linenoise.o: linenoise.c linenoise.h
embcache.o: embcache.c embcache.h
flux_cuda.o: flux_cuda.cu flux_cuda.h
main.o: main.c flux.h flux_kernels.h flux_cli.h terminals.h