From db66f4b7eb31760ec35a2408bbec3ed71c7c9c9a Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Sat, 7 Feb 2026 21:23:19 -0500 Subject: [PATCH 1/2] Add cuda support --- Makefile | 35 +- README.md | 203 ++-- flux.c | 126 ++- flux_cuda.cu | 1506 ++++++++++++++++++++++++++++++ flux_cuda.h | 128 +++ flux_kernels.c | 898 +++++++++++++++++- flux_kernels.h | 16 + flux_qwen3.c | 66 +- flux_transformer.c | 2186 ++++++++++++++++++++++++++++---------------- flux_vae.c | 559 ++++++++++- main.c | 215 +++-- 11 files changed, 4850 insertions(+), 1088 deletions(-) create mode 100644 flux_cuda.cu create mode 100644 flux_cuda.h diff --git a/Makefile b/Makefile index a516d03..3f73a9f 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ LIB = libflux.a # Debug build flags DEBUG_CFLAGS = -Wall -Wextra -g -O0 -DDEBUG -fsanitize=address -.PHONY: all clean debug lib install info test pngtest help generic blas mps +.PHONY: all clean debug lib install info test pngtest help generic blas cuda mps # Default: show available targets all: help @@ -32,6 +32,9 @@ help: @echo "Choose a backend:" @echo " make generic - Pure C, no dependencies (slow)" @echo " make blas - With BLAS acceleration (~30x faster)" +ifeq ($(UNAME_S),Linux) + @echo " make cuda - NVIDIA CUDA + cuBLAS (GPU acceleration)" +endif ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_M),arm64) @echo " make mps - Apple Silicon with Metal GPU (fastest)" @@ -70,6 +73,25 @@ blas: clean $(TARGET) @echo "" @echo "Built with BLAS backend (~30x faster than generic)" +# ============================================================================= +# Backend: cuda (NVIDIA CUDA + cuBLAS, Linux) +# ============================================================================= +ifeq ($(UNAME_S),Linux) +CUDA_HOME ?= /usr/local/cuda +CUDA_CFLAGS = $(CFLAGS_BASE) -DUSE_BLAS -DUSE_OPENBLAS -DUSE_CUDA -I/usr/include/openblas -I$(CUDA_HOME)/include +CUDA_LDFLAGS = -L$(CUDA_HOME)/lib64 -Wl,-rpath,$(CUDA_HOME)/lib64 -lcublasLt -lcublas -lcudart -lopenblas -lstdc++ -lm + +cuda: CFLAGS = $(CUDA_CFLAGS) +cuda: LDFLAGS = $(CUDA_LDFLAGS) +cuda: clean cuda-build + @echo "" + @echo "Built with CUDA backend (cuBLAS GPU acceleration)" +else +cuda: + @echo "Error: CUDA backend requires Linux with NVIDIA CUDA toolkit" + @exit 1 +endif + # ============================================================================= # Backend: mps (Apple Silicon Metal GPU) # ============================================================================= @@ -113,6 +135,9 @@ endif $(TARGET): $(OBJS) $(CLI_OBJS) main.o $(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) +cuda-build: $(OBJS) $(CLI_OBJS) main.o flux_cuda.o + $(CC) $(CUDA_CFLAGS) -o $(TARGET) $^ $(CUDA_LDFLAGS) + lib: $(LIB) $(LIB): $(OBJS) @@ -153,9 +178,12 @@ install: $(TARGET) $(LIB) install -m 644 flux_kernels.h /usr/local/include/ clean: - rm -f $(OBJS) $(CLI_OBJS) *.mps.o flux_metal.o main.o $(TARGET) $(LIB) + rm -f $(OBJS) $(CLI_OBJS) *.mps.o flux_metal.o flux_cuda.o main.o $(TARGET) $(LIB) rm -f flux_shaders_source.h +flux_cuda.o: flux_cuda.cu flux_cuda.h + nvcc -O3 -U_GNU_SOURCE -c -o $@ $< + info: @echo "Platform: $(UNAME_S) $(UNAME_M)" @echo "Compiler: $(CC)" @@ -169,6 +197,7 @@ ifeq ($(UNAME_M),arm64) endif else @echo " blas - OpenBLAS (requires libopenblas-dev)" + @echo " cuda - NVIDIA CUDA + cuBLAS (requires CUDA toolkit + OpenBLAS)" endif # ============================================================================= @@ -183,9 +212,11 @@ flux_sample.o: flux_sample.c flux.h flux_kernels.h flux_image.o: flux_image.c flux.h flux_safetensors.o: flux_safetensors.c flux_safetensors.h flux_qwen3.o: flux_qwen3.c flux_qwen3.h flux_safetensors.h +flux_qwen3.o: flux_cuda.h flux_qwen3_tokenizer.o: flux_qwen3_tokenizer.c flux_qwen3.h terminals.o: terminals.c terminals.h flux.h flux_cli.o: flux_cli.c flux_cli.h flux.h flux_qwen3.h embcache.h linenoise.h terminals.h linenoise.o: linenoise.c linenoise.h embcache.o: embcache.c embcache.h main.o: main.c flux.h flux_kernels.h flux_cli.h terminals.h +flux_transformer.o: flux_cuda.h diff --git a/README.md b/README.md index 2353621..9530201 100644 --- a/README.md +++ b/README.md @@ -1,75 +1,64 @@ -# FLUX.2-klein Pure C Implementation +# FLUX.2-klein-4B Pure C Implementation -This program generates images from text prompts (and optionally from other images) using the [FLUX.2-klein models](https://bfl.ai/models/flux-2-klein) from [Black Forest Labs](https://bfl.ai/). It can be used as a library as well, and is implemented entirely in C, with zero external dependencies beyond the C standard library. MPS and BLAS acceleration are optional but recommended. +This program generates images from text prompts (and optionally from other images) using the [FLUX.2-klein-4B model](https://bfl.ai/models/flux-2-klein) from [Black Forest Labs](https://bfl.ai/). It can be used as a library as well, and is implemented entirely in C, with zero external dependencies beyond the C standard library in `make generic` mode. MPS, CUDA, and BLAS acceleration are optional. Supported models: -- **Flux.2 Klein 4B distilled** (4 steps, auto guidance set to 1, very fast). -- **Flux.2 Klein 4B base** (50 steps for max quality, or less. Classifier-Free Diffusion Guidance, much slower but more generation variety). -- **Flux.2 Klein 9B distilled** (4 steps, larger model, higher quality. Non-commercial license). -- **Flux.2 Klein 9B base** (50 steps, CFG, highest quality. Non-commercial license). +- Flux.2 4B Klein distilled model (4 steps, auto guidance set to 1, very fast). +- Flux.2 4B Klein base model. (50 steps for max quality, or less. Classifier-Free Diffusion Guidance, much slower but more generation variety). ## Quick Start ```bash # Build (choose your backend) make mps # Apple Silicon (fastest) +# or: make cuda # Linux + NVIDIA GPU (cuBLAS) # or: make blas # Intel Mac / Linux with OpenBLAS # or: make generic # Pure C, no dependencies # Download the model (~16GB) - pick one: -./download_model.sh 4b # using curl -# or: pip install huggingface_hub && python download_model.py 4b +./download_model.sh # using curl +# or: pip install huggingface_hub && python download_model.py # Generate an image -./flux -d flux-klein-4b -p "A woman wearing sunglasses" -o output.png +./flux -d flux-klein-model -p "A woman wearing sunglasses" -o output.png ``` If you want to try the base model, instead of the distilled one (much slower, higher quality), use the following instructions. Use 10 steps if your computer is quite slow, instead of the default of 50, it will still work well enough to test it (10 seconds to generate a 256x256 image on a MacBook M3 Max). ``` -./download_model.sh 4b-base -# or: pip install huggingface_hub && python download_model.py 4b-base -./flux -d flux-klein-4b-base -p "A woman wearing sunglasses" -o output.png +./download_model.sh --base +# or: pip install huggingface_hub && python download_model.py --base +./flux -d flux-klein-base-model -p "A woman wearing sunglasses" -o output.png ``` -If you want to try the 9B model (higher quality, non-commercial license, ~30GB download): -```bash -# 9B is a gated model - you need a HuggingFace token -# 1. Accept the license at https://huggingface.co/black-forest-labs/FLUX.2-klein-9B -# 2. Get your token from https://huggingface.co/settings/tokens -./download_model.sh 9b --token YOUR_TOKEN -# or: python download_model.py 9b --token YOUR_TOKEN -# or: set HF_TOKEN env var -./flux -d flux-klein-9b -p "A woman wearing sunglasses" -o output.png -``` - -That's it. No Python runtime or CUDA toolkit required at inference time. +That's it. No Python runtime required at inference time. CUDA is optional and only needed when building with `make cuda`. ## Example Output ![Woman with sunglasses](images/woman_with_sunglasses.png) -*Generated with: `./flux -d flux-klein-4b -p "A picture of a woman in 1960 America. Sunglasses. ASA 400 film. Black and White." -W 512 -H 512 -o woman.png`* +*Generated with: `./flux -d flux-klein-model -p "A picture of a woman in 1960 America. Sunglasses. ASA 400 film. Black and White." -W 512 -H 512 -o woman.png`* ### Image-to-Image Example ![antirez to drawing](images/antirez_to_drawing.png) -*Generated with: `./flux -i antirez.png -o antirez_to_drawing.png -p "make it a drawing" -d flux-klein-4b`* +*Generated with: `./flux -i antirez.png -o antirez_to_drawing.png -p "make it a drawing" -d flux-klein-model`* ## Features -- **Zero dependencies**: Pure C implementation, works standalone. BLAS optional for ~30x speedup (Apple Accelerate on macOS, OpenBLAS on Linux) +- **Zero dependencies in generic mode**: Pure C implementation works standalone. Optional BLAS/CUDA backends provide significant speedups. - **Metal GPU acceleration**: Automatic on Apple Silicon Macs. Performance matches PyTorch's optimized MPS pipeline +- **CUDA GPU acceleration**: Optional Linux backend via cuBLAS (`make cuda`) for NVIDIA GPUs - **Runs where Python can't**: Memory-mapped weights (default) enable inference on 8GB RAM systems where the Python ML stack cannot run FLUX.2 at all - **Text-to-image**: Generate images from text prompts - **Image-to-image**: Transform existing images guided by prompts - **Multi-reference**: Combine multiple reference images (e.g., `-i car.png -i beach.png` for "car on beach") -- **Integrated text encoder**: Qwen3 encoder built-in (4B or 8B depending on model), no external embedding computation needed -- **Memory efficient**: Automatic encoder release after encoding (up to ~16GB freed) -- **Memory-mapped weights**: Enabled by default. Reduces peak memory from ~16GB to ~4-5GB. Fastest mode on MPS; BLAS users with plenty of RAM may prefer `--no-mmap` for faster inference +- **Integrated text encoder**: Qwen3-4B encoder built-in, no external embedding computation needed +- **Memory efficient**: Automatic encoder release after encoding (~8GB freed) +- **Memory-mapped weights**: Enabled by default. Reduces peak memory from ~16GB to ~4-5GB. Fastest mode on MPS; BLAS/CUDA users with plenty of RAM may prefer `--no-mmap` for faster inference - **Size-independent seeds**: Same seed produces similar compositions at different resolutions. Explore at 256×256, then render at 512×512 with the same seed -- **Terminal image display**: watch the resulting image without leaving your terminal (Ghostty, Kitty, iTerm2, WezTerm, or Konsole). +- **Terminal image display**: watch the resulting image without leaving your terminal (Ghostty, Kitty, iTerm2, or Konsole). ### Terminal Image Display @@ -78,14 +67,14 @@ That's it. No Python runtime or CUDA toolkit required at inference time. Display generated images directly in your terminal with `--show`, or watch the denoising process step-by-step with `--show-steps`: ```bash -# Display final image in terminal (auto-detects Kitty/Ghostty/iTerm2/WezTerm/Konsole) -./flux -d flux-klein-4b -p "a cute robot" -o robot.png --show +# Display final image in terminal (auto-detects Kitty/Ghostty/iTerm2/Konsole) +./flux -d flux-klein-model -p "a cute robot" -o robot.png --show # Display each denoising step (slower, but interesting to watch) -./flux -d flux-klein-4b -p "a cute robot" -o robot.png --show-steps +./flux -d flux-klein-model -p "a cute robot" -o robot.png --show-steps ``` -Requires a terminal supporting the [Kitty graphics protocol](https://sw.kovidgoyal.net/kitty/graphics-protocol/) (such as [Kitty](https://sw.kovidgoyal.net/kitty/) or [Ghostty](https://ghostty.org/)), the iTerm2 inline image protocol ([iTerm2](https://iterm2.com/), [WezTerm](https://wezfurlong.org/wezterm/)), or [Konsole](https://konsole.kde.org/). Terminal type is auto-detected from environment variables. +Requires a terminal supporting the [Kitty graphics protocol](https://sw.kovidgoyal.net/kitty/graphics-protocol/) (such as [Kitty](https://sw.kovidgoyal.net/kitty/) or [Ghostty](https://ghostty.org/)), [iTerm2](https://iterm2.com/), or [Konsole](https://konsole.kde.org/). Terminal type is auto-detected from environment variables. Use `--zoom N` to adjust the display size (default: 2 for Retina displays, use 1 for non-HiDPI screens). @@ -94,7 +83,7 @@ Use `--zoom N` to adjust the display size (default: 2 for Retina displays, use 1 ### Text-to-Image ```bash -./flux -d flux-klein-4b -p "A fluffy orange cat sitting on a windowsill" -o cat.png +./flux -d flux-klein-model -p "A fluffy orange cat sitting on a windowsill" -o cat.png ``` ### Image-to-Image @@ -102,7 +91,7 @@ Use `--zoom N` to adjust the display size (default: 2 for Retina displays, use 1 Transform an existing image based on a prompt: ```bash -./flux -d flux-klein-4b -p "oil painting style" -i photo.png -o painting.png +./flux -d flux-klein-model -p "oil painting style" -i photo.png -o painting.png ``` FLUX.2 uses **in-context conditioning** for image-to-image generation. Unlike traditional approaches that add noise to the input image, FLUX.2 passes the reference image as additional tokens that the model can attend to during generation. This means: @@ -119,7 +108,7 @@ FLUX.2 uses **in-context conditioning** for image-to-image generation. Unlike tr **Super Resolution:** Since the reference image can be a different size than the output, you can use img2img for upscaling: ```bash -./flux -d flux-klein-4b -i small.png -W 1024 -H 1024 -o big.png -p "Create an exact copy of the input image." +./flux -d flux-klein-model -i small.png -W 1024 -H 1024 -o big.png -p "Create an exact copy of the input image." ``` The model will generate a higher-resolution version while preserving the composition and details of the input. @@ -129,7 +118,7 @@ The model will generate a higher-resolution version while preserving the composi Combine elements from multiple reference images: ```bash -./flux -d flux-klein-4b -i car.png -i beach.png -p "a sports car on the beach" -o result.png +./flux -d flux-klein-model -i car.png -i beach.png -p "a sports car on the beach" -o result.png ``` Each reference image is encoded separately and passed to the transformer with different positional embeddings (T=10, T=20, T=30, ...). The model attends to all references during generation, allowing it to combine elements from each. @@ -147,7 +136,7 @@ You can specify up to 16 reference images with multiple `-i` flags. The prompt g Start without `-p` to enter interactive mode: ```bash -./flux -d flux-klein-4b +./flux -d flux-klein-model ``` Generate images by typing prompts. Each image gets a `$N` reference ID: @@ -204,7 +193,7 @@ Done -> /tmp/flux-.../image-0003.png (ref $2) ``` -q, --quiet Silent mode, no output -v, --verbose Show detailed config and timing info - --show Display image in terminal (auto-detects Kitty/Ghostty/iTerm2/WezTerm/Konsole) + --show Display image in terminal (auto-detects Kitty/Ghostty/iTerm2/Konsole) --show-steps Display each denoising step (slower) --zoom N Terminal image zoom factor (default: 2 for Retina) ``` @@ -213,7 +202,6 @@ Done -> /tmp/flux-.../image-0003.png (ref $2) ``` -m, --mmap Memory-mapped weights (default, fastest on MPS) --no-mmap Disable mmap, load all weights upfront - --no-license-info Suppress non-commercial license warning (9B model) -e, --embeddings PATH Load pre-computed text embeddings (advanced) -h, --help Show help ``` @@ -222,7 +210,7 @@ Done -> /tmp/flux-.../image-0003.png (ref $2) The seed is always printed to stderr, even when random: ``` -$ ./flux -d flux-klein-4b -p "a landscape" -o out.png +$ ./flux -d flux-klein-model -p "a landscape" -o out.png Seed: 1705612345 ... Saving... out.png 256x256 (0.1s) @@ -230,7 +218,7 @@ Saving... out.png 256x256 (0.1s) To reproduce the same image, use the printed seed: ``` -$ ./flux -d flux-klein-4b -p "a landscape" -o out.png -S 1705612345 +$ ./flux -d flux-klein-model -p "a landscape" -o out.png -S 1705612345 ``` ## PNG Metadata @@ -250,7 +238,7 @@ identify -verbose image.png | grep -A1 "Properties:" The following metadata fields are stored: - `flux:seed` - The random seed used for generation -- `flux:model` - The model name (e.g., FLUX.2-klein-4B, FLUX.2-klein-9B) +- `flux:model` - The model name (FLUX.2-klein-4B) - `Software` - Program identifier ## Building @@ -261,12 +249,14 @@ Choose a backend when building: make # Show available backends make generic # Pure C, no dependencies (slow) make blas # BLAS acceleration (~30x faster) +make cuda # CUDA + cuBLAS GPU acceleration (Linux + NVIDIA) make mps # Apple Silicon Metal GPU (fastest, macOS only) ``` **Recommended:** - macOS Apple Silicon: `make mps` - macOS Intel: `make blas` +- Linux + NVIDIA GPU: `make cuda` - Linux with OpenBLAS: `make blas` - Linux without OpenBLAS: `make generic` @@ -279,6 +269,19 @@ sudo apt install libopenblas-dev sudo dnf install openblas-devel ``` +For `make cuda` on Linux, install: +- NVIDIA driver +- CUDA toolkit (for `cublas` and `cudart`) +- OpenBLAS development headers (`libopenblas-dev` or `openblas-devel`) + +CUDA tuning: +- `FLUX_CUDA_MIN_OPS` controls the minimum GEMM op-count (`M*K*N`) dispatched to CUDA. +- Default is `2097152` (2M ops). +- Example: `FLUX_CUDA_MIN_OPS=2000000 ./flux -d flux-klein-model -p "a cat" -o out.png` +- `FLUX_CUDA_WEIGHT_CACHE_MB` controls GPU memory used to cache frequently reused linear weights. +- Default: auto (up to 1024MB, capped to 25% of free VRAM at startup). Set `0` to disable. +- Example: `FLUX_CUDA_WEIGHT_CACHE_MB=2048 ./flux --no-mmap -d flux-klein-model -p "a cat" -o out.png` + Other targets: ```bash make clean # Clean build artifacts @@ -314,35 +317,23 @@ python3 run_test.py --flux-binary ./flux --model-dir /path/to/model Download model weights from HuggingFace using one of these methods: -**4B Distilled model** (~16GB, fast 4-step inference): +**Distilled model** (~16GB, fast 4-step inference): ```bash -./download_model.sh 4b # using curl -# or: python download_model.py 4b # using huggingface_hub +./download_model.sh # using curl +# or: python download_model.py # using huggingface_hub ``` -**4B Base model** (~16GB, 50-step inference with CFG, higher quality): +**Base model** (~16GB, 50-step inference with CFG, higher quality): ```bash -./download_model.sh 4b-base -# or: python download_model.py 4b-base +./download_model.sh --base +# or: python download_model.py --base ``` -**9B models** (~30GB, higher quality, non-commercial license): -```bash -# 9B models are gated - require HuggingFace authentication -# 1. Accept the license at https://huggingface.co/black-forest-labs/FLUX.2-klein-9B -# 2. Get a token from https://huggingface.co/settings/tokens -./download_model.sh 9b --token YOUR_TOKEN # distilled -./download_model.sh 9b-base --token YOUR_TOKEN # base (CFG, highest quality) -# or: python download_model.py 9b --token YOUR_TOKEN -# You can also set the HF_TOKEN environment variable -``` - -| Model | Directory | Size | Components | -|-------|-----------|------|------------| -| 4B distilled | `./flux-klein-4b` | ~16GB | VAE (~300MB), Transformer (~4GB), Qwen3-4B (~8GB) | -| 4B base | `./flux-klein-4b-base` | ~16GB | VAE (~300MB), Transformer (~4GB), Qwen3-4B (~8GB) | -| 9B distilled | `./flux-klein-9b` | ~30GB | VAE (~300MB), Transformer (~17GB), Qwen3-8B (~15GB) | -| 9B base | `./flux-klein-9b-base` | ~30GB | VAE (~300MB), Transformer (~17GB), Qwen3-8B (~15GB) | +The distilled model downloads to `./flux-klein-model`, the base model to `./flux-klein-base-model`. Both contain: +- VAE (~300MB) +- Transformer (~4GB) +- Qwen3-4B Text Encoder (~8GB) +- Tokenizer ## How Fast Is It? @@ -363,20 +354,6 @@ The MPS implementation is faster than the PyTorch optimized pipeline at all reso - The `make generic` backend (pure C, no BLAS) is approximately 30x slower than BLAS and not included in benchmarks. - The fastest implementation for Metal remains [the Draw Things app](https://drawthings.ai/) that can produce a 1024x1024 image in just 14.23 seconds (in the same hardware), however it is worth noting that it uses 6-bit quantized weights, while this implementation uses the official BF16 weights. The 6-bit quantization used by Draw Things provides both a big memory win and a moderate speed advantage (not nearly as much as it could in an LLM, where causal attention is dominated by memory bandwidth); if we account for this, the performance is comparable. -### Community Benchmarks - -The following timings for 512x512 generation (distilled model, 4 steps) were reported by users of Flux2.c. They can serve as a rough indication of the performance you could expect, but results vary widely depending on the hardware, Metal availability (the code is heavily optimized for Apple Silicon via MPS), and whether BLAS acceleration is used on CPU. - -| Hardware | Backend | 512x512 | -|----------|---------|---------| -| M3 Ultra | MPS | 4.5s | -| M3 Max | MPS | 7.6s | -| MacBook Pro M4 | MPS | 19s | -| MacBook Pro M1 Max | MPS | 39.9s | -| Apple M1 Pro | MPS | 42.4s | -| AMD Ryzen 7800X3D | BLAS | 47.8s | -| Intel i5-1135G7 | BLAS | 218s | - ## Resolution Limits **Maximum resolution**: 1792x1792 pixels. The model produces good results up to this size; beyond this resolution image quality degrades significantly (this is a model limitation, not an implementation issue). @@ -387,21 +364,15 @@ Dimensions should be multiples of 16 (the VAE downsampling factor). ## Model Architecture -All models share the same rectified flow transformer architecture, differing only in dimensions: +Both models share the same architecture, a rectified flow transformer: -| Component | 4B | 9B | -|-----------|-----|-----| -| Transformer hidden | 3072 | 4096 | -| Attention heads | 24 | 32 | -| Head dim | 128 | 128 | -| Double blocks | 5 | 8 | -| Single blocks | 20 | 24 | -| Text Encoder | Qwen3-4B (2560 hidden, 36 layers) | Qwen3-8B (4096 hidden, 36 layers) | -| VAE | AutoencoderKL, 128 latent channels, 8x spatial compression | Same | +| Component | Architecture | +|-----------|-------------| +| Transformer | 5 double blocks + 20 single blocks, 3072 hidden dim, 24 attention heads | +| VAE | AutoencoderKL, 128 latent channels, 8x spatial compression | +| Text Encoder | Qwen3-4B, 36 layers, 2560 hidden dim | -Architecture dimensions are read automatically from the model's config JSON files at load time. - -The distilled and base variants differ in inference: +The models differ in inference: | | Distilled | Base | |---|-----------|------| @@ -409,7 +380,7 @@ The distilled and base variants differ in inference: | CFG guidance | 1.0 (none) | 4.0 (default) | | Passes per step | 1 | 2 (conditioned + unconditioned) | -The model type (distilled vs base, 4B vs 9B) is autodetected from the model directory. Use `--base` to force base model mode if autodetection fails. +The model type is autodetected from `model_index.json` in the model directory. Use `--base` to force base model mode if autodetection fails. **Classifier-Free Guidance (CFG)**: The base model runs the transformer twice per step — once with an empty prompt (unconditioned) and once with the real prompt (conditioned). The final velocity is `v = v_uncond + guidance * (v_cond - v_uncond)`. This makes each step ~2x slower than the distilled model, and the base model needs ~12x more steps, making it roughly 25x slower overall. @@ -427,13 +398,13 @@ The `--power` flag provides a middle ground: a power curve schedule (`t = 1 - (i ```bash # Base model with 10 steps and linear schedule -./flux -d flux-klein-4b-base -p "a cat" -o cat.png -s 10 --linear +./flux -d flux-klein-base-model -p "a cat" -o cat.png -s 10 --linear # Base model with power schedule (quadratic by default) -./flux -d flux-klein-4b-base -p "a cat" -o cat.png -s 10 --power +./flux -d flux-klein-base-model -p "a cat" -o cat.png -s 10 --power # Power schedule with custom exponent -./flux -d flux-klein-4b-base -p "a cat" -o cat.png -s 10 --power-alpha 1.5 +./flux -d flux-klein-base-model -p "a cat" -o cat.png -s 10 --power-alpha 1.5 ``` In interactive CLI mode, toggle with `!linear` or `!power [alpha]`. @@ -444,8 +415,6 @@ If you have a terminal supporting the iTerm2 or Kitty terminal graphics protocol ## Memory Requirements -### 4B model - With mmap (default): | Phase | Memory | @@ -462,24 +431,6 @@ With `--no-mmap` (all weights in RAM): | Diffusion | ~8GB (transformer ~4GB + VAE ~300MB + activations) | | Peak | ~16GB (if encoder not released) | -### 9B model - -With mmap (default): - -| Phase | Memory | -|-------|--------| -| Text encoding | ~3-4GB (larger layers loaded on-demand) | -| Diffusion | ~2-3GB (more/larger blocks loaded on-demand) | -| Peak | ~8-10GB | - -With `--no-mmap` (all weights in RAM): - -| Phase | Memory | -|-------|--------| -| Text encoding | ~15GB (Qwen3-8B encoder weights) | -| Diffusion | ~17GB (transformer ~17GB + VAE ~300MB + activations) | -| Peak | ~32GB (if encoder not released) | - The text encoder is automatically released after encoding, reducing peak memory during diffusion. If you generate multiple images with different prompts, the encoder reloads automatically. ## Memory-Mapped Weights (Default) @@ -487,8 +438,8 @@ The text encoder is automatically released after encoding, reducing peak memory Memory-mapped weight loading is enabled by default. Use `--no-mmap` to disable and load all weights upfront. ```bash -./flux -d flux-klein-4b -p "A cat" -o cat.png # mmap (default) -./flux -d flux-klein-4b -p "A cat" -o cat.png --no-mmap # load all upfront +./flux -d flux-klein-model -p "A cat" -o cat.png # mmap (default) +./flux -d flux-klein-model -p "A cat" -o cat.png --no-mmap # load all upfront ``` **How it works:** Instead of loading all model weights into RAM upfront, mmap keeps the safetensors files memory-mapped and loads weights on-demand: @@ -504,6 +455,8 @@ This reduces peak memory from ~16GB to ~4-5GB, making inference possible on 16GB - **BLAS (CPU):** mmap is **slightly slower** but uses much less RAM. BLAS requires f32 weights, so each block must be converted from bf16→f32 on every step (25 blocks × 4 steps = 100 conversions). With `--no-mmap`, this conversion happens once at startup. **Recommendation:** If you have 32GB+ RAM and use BLAS, try `--no-mmap` for faster inference. If RAM is limited, mmap lets you run at all. +- **CUDA (Linux + NVIDIA):** mmap is usually **slower** for the same reason as BLAS (repeated bf16→f32 conversions and on-demand block loading in the denoiser loop). **Recommendation:** If you have enough host RAM, prefer `--no-mmap` for faster generation. + - **Generic (pure C):** Same tradeoffs as BLAS, but slower overall. ## C Library API @@ -520,7 +473,7 @@ Here's a complete program that generates an image from a text prompt: int main(void) { /* Load the model. This loads VAE, transformer, and text encoder. */ - flux_ctx *ctx = flux_load_dir("flux-klein-4b"); + flux_ctx *ctx = flux_load_dir("flux-klein-model"); if (!ctx) { fprintf(stderr, "Failed to load model: %s\n", flux_get_error()); return 1; @@ -566,7 +519,7 @@ Transform an existing image guided by a text prompt using in-context conditionin #include int main(void) { - flux_ctx *ctx = flux_load_dir("flux-klein-4b"); + flux_ctx *ctx = flux_load_dir("flux-klein-model"); if (!ctx) return 1; /* Load the input image */ @@ -606,7 +559,7 @@ int main(void) { When generating multiple images with different seeds but the same prompt, you can avoid reloading the text encoder: ```c -flux_ctx *ctx = flux_load_dir("flux-klein-4b"); +flux_ctx *ctx = flux_load_dir("flux-klein-model"); flux_params params = FLUX_PARAMS_DEFAULT; params.width = 256; params.height = 256; @@ -721,7 +674,7 @@ This saves to `/tmp/`: 4. Run C with the same inputs: ```bash -./flux -d flux-klein-4b --debug-py -W 256 -H 256 --steps 4 -o /tmp/c_debug.png +./flux -d flux-klein-model --debug-py -W 256 -H 256 --steps 4 -o /tmp/c_debug.png ``` 5. Compare the outputs visually or numerically. diff --git a/flux.c b/flux.c index 52865b0..351227a 100644 --- a/flux.c +++ b/flux.c @@ -14,6 +14,9 @@ #include #include #include +#if defined(__unix__) || defined(__APPLE__) +#include +#endif #ifdef USE_METAL #include "flux_metal.h" @@ -43,8 +46,8 @@ extern flux_image *flux_vae_decode(flux_vae_t *vae, const float *latent, extern float *flux_image_to_tensor(const flux_image *img); extern flux_transformer_t *flux_transformer_load(FILE *f); -extern flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir); -extern flux_transformer_t *flux_transformer_load_safetensors_mmap(const char *model_dir); +extern flux_transformer_t *flux_transformer_load_safetensors(safetensors_file_t *sf); +extern flux_transformer_t *flux_transformer_load_safetensors_mmap(safetensors_file_t *sf); extern void flux_transformer_free(flux_transformer_t *tf); extern float *flux_transformer_forward(flux_transformer_t *tf, const float *img_latent, int img_h, int img_w, @@ -338,17 +341,27 @@ void flux_release_text_encoder(flux_ctx *ctx) { #endif } -/* Load transformer on-demand if not already loaded */ -static int flux_load_transformer_if_needed(flux_ctx *ctx) { +/* Load transformer on-demand if not already loaded. */ +static int flux_load_transformer_internal(flux_ctx *ctx, int emit_phase_callbacks) { if (ctx->transformer) return 1; /* Already loaded */ - if (flux_phase_callback) flux_phase_callback("Loading FLUX.2 transformer", 0); - if (ctx->use_mmap) { - ctx->transformer = flux_transformer_load_safetensors_mmap(ctx->model_dir); - } else { - ctx->transformer = flux_transformer_load_safetensors(ctx->model_dir); + char path[1024]; + snprintf(path, sizeof(path), "%s/transformer/diffusion_pytorch_model.safetensors", + ctx->model_dir); + + if (emit_phase_callbacks && flux_phase_callback) flux_phase_callback("Loading FLUX.2 transformer", 0); + safetensors_file_t *sf = safetensors_open(path); + if (sf) { + if (ctx->use_mmap) { + /* Mmap mode: load only small weights, keep sf open for on-demand loading. + * The transformer takes ownership of sf and will close it on free. */ + ctx->transformer = flux_transformer_load_safetensors_mmap(sf); + } else { + ctx->transformer = flux_transformer_load_safetensors(sf); + safetensors_close(sf); + } } - if (flux_phase_callback) flux_phase_callback("Loading FLUX.2 transformer", 1); + if (emit_phase_callbacks && flux_phase_callback) flux_phase_callback("Loading FLUX.2 transformer", 1); if (!ctx->transformer) { set_error("Failed to load transformer"); @@ -357,6 +370,59 @@ static int flux_load_transformer_if_needed(flux_ctx *ctx) { return 1; } +static int flux_load_transformer_if_needed(flux_ctx *ctx) { + return flux_load_transformer_internal(ctx, 1); +} + +#if defined(__unix__) || defined(__APPLE__) +typedef struct { + flux_ctx *ctx; + pthread_t thread; + int started; + int ok; +} flux_tf_preload_task_t; + +static int flux_overlap_preload_enabled(void) { + static int cached = -1; + if (cached == -1) { + cached = getenv("FLUX_OVERLAP_PRELOAD") ? 1 : 0; + } + return cached; +} + +static void *flux_transformer_preload_thread(void *arg) { + flux_tf_preload_task_t *task = (flux_tf_preload_task_t *)arg; + task->ok = flux_load_transformer_internal(task->ctx, 0); + return NULL; +} + +static void flux_transformer_preload_begin(flux_ctx *ctx, flux_tf_preload_task_t *task) { + if (!task) return; + memset(task, 0, sizeof(*task)); + task->ctx = ctx; + + if (!ctx || ctx->transformer) return; + if (!flux_overlap_preload_enabled()) return; + + if (pthread_create(&task->thread, NULL, flux_transformer_preload_thread, task) == 0) { + task->started = 1; + } +} + +static void flux_transformer_preload_join(flux_tf_preload_task_t *task) { + if (!task || !task->started) return; + pthread_join(task->thread, NULL); + task->started = 0; +} +#else +typedef struct { int started; int ok; } flux_tf_preload_task_t; +static void flux_transformer_preload_begin(flux_ctx *ctx, flux_tf_preload_task_t *task) { + (void)ctx; + if (task) memset(task, 0, sizeof(*task)); +} +static void flux_transformer_preload_join(flux_tf_preload_task_t *task) { (void)task; } +#endif + /* Get transformer for debugging */ void *flux_get_transformer(flux_ctx *ctx) { return ctx ? ctx->transformer : NULL; @@ -432,10 +498,16 @@ flux_image *flux_generate(flux_ctx *ctx, const char *prompt, return NULL; } + /* Optional overlap: load transformer in parallel while text is encoding. + * This is disabled by default and enabled via FLUX_OVERLAP_PRELOAD=1. */ + flux_tf_preload_task_t tf_preload; + flux_transformer_preload_begin(ctx, &tf_preload); + /* Encode text (and unconditioned text for CFG in base model) */ int text_seq; float *text_emb = flux_encode_text(ctx, prompt, &text_seq); if (!text_emb) { + flux_transformer_preload_join(&tf_preload); set_error("Failed to encode prompt"); return NULL; } @@ -445,14 +517,19 @@ flux_image *flux_generate(flux_ctx *ctx, const char *prompt, if (!ctx->is_distilled) { text_emb_uncond = flux_encode_text(ctx, "", &text_seq_uncond); if (!text_emb_uncond) { + flux_transformer_preload_join(&tf_preload); free(text_emb); set_error("Failed to encode empty prompt for CFG"); return NULL; } } - /* Release text encoder to free ~8GB before loading transformer */ - flux_release_text_encoder(ctx); + /* Ensure any async preload attempt is complete before load checks/fallback. */ + flux_transformer_preload_join(&tf_preload); + + /* Release text encoder only before first transformer load. + * Once transformer is loaded, keeping encoder avoids reload cost on later calls. */ + if (!ctx->transformer) flux_release_text_encoder(ctx); /* Load transformer now (after text encoder is freed to reduce peak memory) */ if (!flux_load_transformer_if_needed(ctx)) { @@ -830,10 +907,15 @@ flux_image *flux_img2img(flux_ctx *ctx, const char *prompt, if (p.num_steps <= 0) p.num_steps = ctx->default_steps; float guidance = (p.guidance > 0) ? p.guidance : ctx->default_guidance; + /* Optional overlap: load transformer in parallel while text is encoding. */ + flux_tf_preload_task_t tf_preload; + flux_transformer_preload_begin(ctx, &tf_preload); + /* Encode text */ int text_seq; float *text_emb = flux_encode_text(ctx, prompt, &text_seq); if (!text_emb) { + flux_transformer_preload_join(&tf_preload); if (resized) flux_image_free(resized); set_error("Failed to encode prompt"); return NULL; @@ -844,6 +926,7 @@ flux_image *flux_img2img(flux_ctx *ctx, const char *prompt, if (!ctx->is_distilled) { text_emb_uncond = flux_encode_text(ctx, "", &text_seq_uncond); if (!text_emb_uncond) { + flux_transformer_preload_join(&tf_preload); free(text_emb); if (resized) flux_image_free(resized); set_error("Failed to encode empty prompt for CFG"); @@ -851,8 +934,11 @@ flux_image *flux_img2img(flux_ctx *ctx, const char *prompt, } } - /* Release text encoder to free ~8GB before loading transformer */ - flux_release_text_encoder(ctx); + flux_transformer_preload_join(&tf_preload); + + /* Release text encoder only before first transformer load. + * Once transformer is loaded, keeping encoder avoids reload cost on later calls. */ + if (!ctx->transformer) flux_release_text_encoder(ctx); /* Load transformer now (after text encoder is freed to reduce peak memory) */ if (!flux_load_transformer_if_needed(ctx)) { @@ -1013,10 +1099,15 @@ flux_image *flux_multiref(flux_ctx *ctx, const char *prompt, if (p.num_steps <= 0) p.num_steps = ctx->default_steps; float guidance = (p.guidance > 0) ? p.guidance : ctx->default_guidance; + /* Optional overlap: load transformer in parallel while text is encoding. */ + flux_tf_preload_task_t tf_preload; + flux_transformer_preload_begin(ctx, &tf_preload); + /* Encode text */ int text_seq; float *text_emb = flux_encode_text(ctx, prompt, &text_seq); if (!text_emb) { + flux_transformer_preload_join(&tf_preload); set_error("Failed to encode prompt"); return NULL; } @@ -1026,13 +1117,18 @@ flux_image *flux_multiref(flux_ctx *ctx, const char *prompt, if (!ctx->is_distilled) { text_emb_uncond = flux_encode_text(ctx, "", &text_seq_uncond); if (!text_emb_uncond) { + flux_transformer_preload_join(&tf_preload); free(text_emb); set_error("Failed to encode empty prompt for CFG"); return NULL; } } - flux_release_text_encoder(ctx); + flux_transformer_preload_join(&tf_preload); + + /* Release text encoder only before first transformer load. + * Once transformer is loaded, keeping encoder avoids reload cost on later calls. */ + if (!ctx->transformer) flux_release_text_encoder(ctx); if (!flux_load_transformer_if_needed(ctx)) { free(text_emb); diff --git a/flux_cuda.cu b/flux_cuda.cu new file mode 100644 index 0000000..ea09824 --- /dev/null +++ b/flux_cuda.cu @@ -0,0 +1,1506 @@ +/* + * FLUX CUDA Attention Helpers + * + * CUDA implementation of attention operations used by the CUDA backend. + * Keeps Q/K/V, attention scores, softmax, and output projection on GPU. + */ + +#include "flux_cuda.h" + +#include +#include +#include +#include + +#include +#include + +/* ------------------------------------------------------------------------- + * Global CUDA state (single-process singleton) + * ------------------------------------------------------------------------- */ + +static cublasHandle_t g_cublas = NULL; +static cublasLtHandle_t g_cublas_lt = NULL; +static cudaStream_t g_stream = NULL; + +static float *g_d_q = NULL; +static float *g_d_k = NULL; +static float *g_d_v = NULL; +static float *g_d_scores = NULL; +static float *g_d_out = NULL; +static float *g_d_q_shd = NULL; +static float *g_d_k_shd = NULL; +static float *g_d_v_shd = NULL; +static float *g_d_out_shd = NULL; +static __nv_bfloat16 *g_d_q_bf16 = NULL; +static __nv_bfloat16 *g_d_k_bf16 = NULL; +static int *g_d_mask = NULL; +static void *g_d_workspace = NULL; + +static size_t g_d_q_bytes = 0; +static size_t g_d_k_bytes = 0; +static size_t g_d_v_bytes = 0; +static size_t g_d_scores_bytes = 0; +static size_t g_d_out_bytes = 0; +static size_t g_d_q_shd_bytes = 0; +static size_t g_d_k_shd_bytes = 0; +static size_t g_d_v_shd_bytes = 0; +static size_t g_d_out_shd_bytes = 0; +static size_t g_d_q_bf16_bytes = 0; +static size_t g_d_k_bf16_bytes = 0; +static size_t g_d_mask_bytes = 0; +static size_t g_d_workspace_bytes = 0; + +static int g_cuda_ready = -1; +static int g_warned = 0; +static int g_flash_mode = -1; + +static void flux_cuda_warn_once(const char *msg) { + if (!g_warned) { + fprintf(stderr, "%s\n", msg); + g_warned = 1; + } +} + +static int flux_cuda_flash_enabled(void) { + if (g_flash_mode == -1) { + g_flash_mode = getenv("FLUX_CUDA_FLASH_ATTN") ? 1 : 0; + } + return g_flash_mode; +} + +static void flux_cuda_cleanup(void) { + if (g_d_q) cudaFree(g_d_q); + if (g_d_k) cudaFree(g_d_k); + if (g_d_v) cudaFree(g_d_v); + if (g_d_scores) cudaFree(g_d_scores); + if (g_d_out) cudaFree(g_d_out); + if (g_d_q_shd) cudaFree(g_d_q_shd); + if (g_d_k_shd) cudaFree(g_d_k_shd); + if (g_d_v_shd) cudaFree(g_d_v_shd); + if (g_d_out_shd) cudaFree(g_d_out_shd); + if (g_d_q_bf16) cudaFree(g_d_q_bf16); + if (g_d_k_bf16) cudaFree(g_d_k_bf16); + if (g_d_mask) cudaFree(g_d_mask); + if (g_d_workspace) cudaFree(g_d_workspace); + + g_d_q = g_d_k = g_d_v = g_d_scores = g_d_out = NULL; + g_d_q_shd = g_d_k_shd = g_d_v_shd = g_d_out_shd = NULL; + g_d_q_bf16 = g_d_k_bf16 = NULL; + g_d_mask = NULL; + g_d_workspace = NULL; + + g_d_q_bytes = g_d_k_bytes = g_d_v_bytes = 0; + g_d_scores_bytes = g_d_out_bytes = 0; + g_d_q_shd_bytes = g_d_k_shd_bytes = g_d_v_shd_bytes = g_d_out_shd_bytes = 0; + g_d_q_bf16_bytes = g_d_k_bf16_bytes = 0; + g_d_mask_bytes = 0; + g_d_workspace_bytes = 0; + + if (g_cublas) { + cublasDestroy(g_cublas); + g_cublas = NULL; + } + if (g_cublas_lt) { + cublasLtDestroy(g_cublas_lt); + g_cublas_lt = NULL; + } + g_stream = NULL; +} + +static int flux_cuda_ensure_init(void) { + if (g_cuda_ready != -1) { + return g_cuda_ready; + } + + g_cuda_ready = 0; + int device_count = 0; + cudaError_t cuda_err = cudaGetDeviceCount(&device_count); + if (cuda_err != cudaSuccess || device_count <= 0) { + return 0; + } + + if (cublasCreate(&g_cublas) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA attention: failed to create cuBLAS handle"); + flux_cuda_cleanup(); + return 0; + } + if (cublasLtCreate(&g_cublas_lt) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA attention: failed to create cuBLASLt handle"); + flux_cuda_cleanup(); + return 0; + } + if (cublasSetPointerMode(g_cublas, CUBLAS_POINTER_MODE_HOST) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA attention: failed to configure cuBLAS pointer mode"); + flux_cuda_cleanup(); + return 0; + } + if (cublasSetStream(g_cublas, g_stream) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA attention: failed to configure cuBLAS stream"); + flux_cuda_cleanup(); + return 0; + } + + atexit(flux_cuda_cleanup); + g_cuda_ready = 1; + return 1; +} + +int flux_cuda_ops_set_stream(void *stream_handle) { + g_stream = (cudaStream_t)stream_handle; + if (!flux_cuda_ensure_init()) return 0; + return cublasSetStream(g_cublas, g_stream) == CUBLAS_STATUS_SUCCESS; +} + +static int flux_cuda_ensure_buffer(void **buf, size_t *cap_bytes, size_t need_bytes) { + if (*cap_bytes >= need_bytes) { + return 1; + } + + size_t new_cap = need_bytes; + if (*cap_bytes > 0) { + new_cap = *cap_bytes; + while (new_cap < need_bytes) { + new_cap *= 2; + } + } + + void *new_buf = NULL; + if (cudaMalloc(&new_buf, new_cap) != cudaSuccess) { + flux_cuda_warn_once("CUDA attention: device allocation failed"); + return 0; + } + + if (*buf) cudaFree(*buf); + *buf = new_buf; + *cap_bytes = new_cap; + return 1; +} + +/* ------------------------------------------------------------------------- + * CUDA kernels + * ------------------------------------------------------------------------- */ + +__global__ static void f32_to_bf16_kernel(const float *in, __nv_bfloat16 *out, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) out[i] = __float2bfloat16(in[i]); +} + +/* Row-wise masked softmax for attention scores. + * scores: [rows, cols] + * row_local = row % seq_q for causal masking in batched mode. + */ +__global__ static void masked_softmax_kernel(float *scores, + const int *mask, + int rows, int cols, + int seq_q, + int causal, + int use_mask) { + int row = blockIdx.x; + int tid = threadIdx.x; + int row_local = row % seq_q; + float *row_ptr = scores + (size_t)row * cols; + + extern __shared__ float sh[]; + + float local_max = -1e30f; + int local_valid = 0; + for (int c = tid; c < cols; c += blockDim.x) { + int valid = 1; + if (causal && c > row_local) valid = 0; + if (use_mask && mask[c] == 0) valid = 0; + if (valid) { + float v = row_ptr[c]; + local_max = fmaxf(local_max, v); + local_valid = 1; + } + } + + sh[tid] = local_valid ? local_max : -1e30f; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + sh[tid] = fmaxf(sh[tid], sh[tid + stride]); + } + __syncthreads(); + } + + float max_val = sh[0]; + + float local_sum = 0.0f; + for (int c = tid; c < cols; c += blockDim.x) { + int valid = 1; + if (causal && c > row_local) valid = 0; + if (use_mask && mask[c] == 0) valid = 0; + + float e = 0.0f; + if (valid) { + e = expf(row_ptr[c] - max_val); + } + row_ptr[c] = e; + local_sum += e; + } + + sh[tid] = local_sum; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + sh[tid] += sh[tid + stride]; + } + __syncthreads(); + } + + float sum_val = sh[0]; + if (sum_val <= 0.0f) { + for (int c = tid; c < cols; c += blockDim.x) { + row_ptr[c] = 0.0f; + } + return; + } + + float inv_sum = 1.0f / sum_val; + for (int c = tid; c < cols; c += blockDim.x) { + row_ptr[c] *= inv_sum; + } +} + +/* Experimental flash-attention style fused kernel for SHD layout. + * Computes out = softmax(scale * q @ k^T) @ v without materializing scores. + * Constraints: no causal/mask, head_dim <= 256. */ +__global__ static void flash_attn_shd_kernel(float *out, + const float *q, + const float *k, + const float *v, + int heads, int seq_q, int seq_k, + int head_dim, float scale) { + int row_head = blockIdx.x; /* [0, heads * seq_q) */ + int h = row_head / seq_q; + int qi = row_head % seq_q; + int tid = threadIdx.x; + + if (h >= heads || qi >= seq_q) return; + + const float *q_row = q + ((size_t)qi * heads + h) * head_dim; + float out_acc = 0.0f; + + extern __shared__ float sh[]; + float *sh_dot = sh; + __shared__ float running_max; + __shared__ float running_sum; + __shared__ float corr; + __shared__ float weight; + __shared__ int new_max; + + if (tid == 0) { + running_max = -1e30f; + running_sum = 0.0f; + corr = 1.0f; + weight = 0.0f; + new_max = 0; + } + __syncthreads(); + + for (int kj = 0; kj < seq_k; kj++) { + float partial = 0.0f; + if (tid < head_dim) { + partial = q_row[tid] * k[((size_t)kj * heads + h) * head_dim + tid]; + } + sh_dot[tid] = partial; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + sh_dot[tid] += sh_dot[tid + stride]; + } + __syncthreads(); + } + + if (tid == 0) { + float score = sh_dot[0] * scale; + if (score > running_max) { + corr = expf(running_max - score); + running_sum = running_sum * corr + 1.0f; + running_max = score; + weight = 1.0f; + new_max = 1; + } else { + corr = 1.0f; + weight = expf(score - running_max); + running_sum += weight; + new_max = 0; + } + } + __syncthreads(); + + if (tid < head_dim) { + float vv = v[((size_t)kj * heads + h) * head_dim + tid]; + if (new_max) { + out_acc = out_acc * corr + vv; + } else { + out_acc += weight * vv; + } + } + __syncthreads(); + } + + if (tid < head_dim) { + float inv = (running_sum > 0.0f) ? (1.0f / running_sum) : 0.0f; + out[((size_t)qi * heads + h) * head_dim + tid] = out_acc * inv; + } +} + +/* AdaLN: out = (1 + scale) * LN(x) + shift, with shift/scale shared across seq. */ +__global__ static void adaln_norm_kernel(float *out, const float *x, + const float *shift, const float *scale, + int seq, int hidden, float eps) { + int row = blockIdx.x; + int tid = threadIdx.x; + if (row >= seq) return; + + const float *x_row = x + (size_t)row * hidden; + float *o_row = out + (size_t)row * hidden; + + extern __shared__ float sh[]; + float *sh_sum = sh; + float *sh_sq = sh + blockDim.x; + + float local_sum = 0.0f; + float local_sq = 0.0f; + for (int i = tid; i < hidden; i += blockDim.x) { + float v = x_row[i]; + local_sum += v; + local_sq += v * v; + } + sh_sum[tid] = local_sum; + sh_sq[tid] = local_sq; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + sh_sum[tid] += sh_sum[tid + stride]; + sh_sq[tid] += sh_sq[tid + stride]; + } + __syncthreads(); + } + + float mean = sh_sum[0] / (float)hidden; + float var = sh_sq[0] / (float)hidden - mean * mean; + if (var < 0.0f) var = 0.0f; + float inv = rsqrtf(var + eps); + + for (int i = tid; i < hidden; i += blockDim.x) { + float norm = (x_row[i] - mean) * inv; + o_row[i] = (1.0f + scale[i]) * norm + shift[i]; + } +} + +/* Split fused projection output into Q/K/V + MLP gate/up. */ +__global__ static void split_qkv_mlp_kernel(const float *fused, + float *q, float *k, float *v, + float *gate, float *up, + int seq, int hidden, int mlp_hidden) { + int row = blockIdx.x; + int tid = threadIdx.x; + if (row >= seq) return; + + int fused_dim = hidden * 3 + mlp_hidden * 2; + const float *src = fused + (size_t)row * fused_dim; + + float *q_row = q + (size_t)row * hidden; + float *k_row = k + (size_t)row * hidden; + float *v_row = v + (size_t)row * hidden; + for (int i = tid; i < hidden; i += blockDim.x) { + q_row[i] = src[i]; + k_row[i] = src[hidden + i]; + v_row[i] = src[hidden * 2 + i]; + } + + float *g_row = gate + (size_t)row * mlp_hidden; + float *u_row = up + (size_t)row * mlp_hidden; + int off = hidden * 3; + for (int i = tid; i < mlp_hidden; i += blockDim.x) { + g_row[i] = src[off + i]; + u_row[i] = src[off + mlp_hidden + i]; + } +} + +/* Per-head RMSNorm for Q/K with shared head weights. */ +__global__ static void qk_rms_norm_kernel(float *q, float *k, + const float *q_weight, const float *k_weight, + int rows, int head_dim, float eps) { + int row = blockIdx.x; + int tid = threadIdx.x; + if (row >= rows) return; + + float *q_row = q + (size_t)row * head_dim; + float *k_row = k + (size_t)row * head_dim; + + extern __shared__ float sh[]; + float *sh_q = sh; + float *sh_k = sh + blockDim.x; + + float sq_q = 0.0f; + float sq_k = 0.0f; + for (int d = tid; d < head_dim; d += blockDim.x) { + float qv = q_row[d]; + float kv = k_row[d]; + sq_q += qv * qv; + sq_k += kv * kv; + } + sh_q[tid] = sq_q; + sh_k[tid] = sq_k; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + sh_q[tid] += sh_q[tid + stride]; + sh_k[tid] += sh_k[tid + stride]; + } + __syncthreads(); + } + + float inv_q = rsqrtf(sh_q[0] / (float)head_dim + eps); + float inv_k = rsqrtf(sh_k[0] / (float)head_dim + eps); + + for (int d = tid; d < head_dim; d += blockDim.x) { + q_row[d] = q_row[d] * inv_q * q_weight[d]; + k_row[d] = k_row[d] * inv_k * k_weight[d]; + } +} + +/* Apply unified RoPE for text (prefix) + image (suffix) tokens. */ +__global__ static void rope_unified_kernel(float *q, float *k, + const float *txt_cos, const float *txt_sin, + const float *img_cos, const float *img_sin, + int seq, int img_offset, + int heads, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int pairs = head_dim / 2; + int total = seq * heads * pairs; + if (idx >= total) return; + + int pair = idx % pairs; + int t = idx / pairs; + int h = t % heads; + int s = t / heads; + int d = pair * 2; + + const float *cos_row; + const float *sin_row; + if (s < img_offset) { + cos_row = txt_cos + (size_t)s * head_dim; + sin_row = txt_sin + (size_t)s * head_dim; + } else { + int img_s = s - img_offset; + cos_row = img_cos + (size_t)img_s * head_dim; + sin_row = img_sin + (size_t)img_s * head_dim; + } + + size_t base = ((size_t)s * heads + h) * head_dim; + float *qv = q + base; + float *kv = k + base; + + float c = cos_row[d]; + float sn = sin_row[d]; + + float q0 = qv[d]; + float q1 = qv[d + 1]; + qv[d] = q0 * c - q1 * sn; + qv[d + 1] = q1 * c + q0 * sn; + + float k0 = kv[d]; + float k1 = kv[d + 1]; + kv[d] = k0 * c - k1 * sn; + kv[d + 1] = k1 * c + k0 * sn; +} + +__global__ static void silu_mul_kernel(float *gate, const float *up, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n) return; + float g = gate[i]; + float silu = g / (1.0f + expf(-g)); + gate[i] = silu * up[i]; +} + +__global__ static void concat_attn_mlp_kernel(const float *attn, const float *mlp, float *out, + int seq, int hidden, int mlp_hidden) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_dim = hidden + mlp_hidden; + int total = seq * row_dim; + if (idx >= total) return; + + int row = idx / row_dim; + int col = idx % row_dim; + if (col < hidden) { + out[idx] = attn[(size_t)row * hidden + col]; + } else { + out[idx] = mlp[(size_t)row * mlp_hidden + (col - hidden)]; + } +} + +__global__ static void concat_seq_kernel(float *out, const float *a, const float *b, + int seq_a, int seq_b, int hidden) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = (seq_a + seq_b) * hidden; + if (idx >= total) return; + + int row = idx / hidden; + int col = idx % hidden; + if (row < seq_a) { + out[idx] = a[(size_t)row * hidden + col]; + } else { + int rb = row - seq_a; + out[idx] = b[(size_t)rb * hidden + col]; + } +} + +__global__ static void gated_add_kernel(float *hidden, const float *gate, const float *proj, + int seq, int hidden_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = seq * hidden_dim; + if (idx >= total) return; + int col = idx % hidden_dim; + hidden[idx] += gate[col] * proj[idx]; +} + +__global__ static void im2col_nchw_rows_kernel(float *col, const float *in, + int in_ch, int H, int W, + int kH, int kW, int stride, int padding, + int outW, int row_offset, int tile_h, + int K) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tile_pixels = tile_h * outW; + int total = tile_pixels * K; + if (idx >= total) return; + + int pix = idx / K; + int k = idx % K; + + int oh_rel = pix / outW; + int ow = pix % outW; + int oh = row_offset + oh_rel; + + int ic = k / (kH * kW); + int rem = k % (kH * kW); + int kh = rem / kW; + int kw = rem % kW; + + int ih = oh * stride - padding + kh; + int iw = ow * stride - padding + kw; + + float v = 0.0f; + if (ih >= 0 && ih < H && iw >= 0 && iw < W) { + v = in[(size_t)ic * H * W + (size_t)ih * W + iw]; + } + + col[idx] = v; +} + +__global__ static void add_bias_rows_kernel(float *rows, const float *bias, + int rows_count, int cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = rows_count * cols; + if (idx >= total) return; + int col = idx % cols; + rows[idx] += bias[col]; +} + +__global__ static void rows_to_nchw_tile_kernel(float *out, const float *rows, + int out_ch, int outH, int outW, + int row_offset, int tile_h) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tile_pixels = tile_h * outW; + int total = out_ch * tile_pixels; + if (idx >= total) return; + + int oc = idx / tile_pixels; + int pix = idx % tile_pixels; + int oh = row_offset + (pix / outW); + int ow = pix % outW; + + out[(size_t)oc * outH * outW + (size_t)oh * outW + ow] = rows[(size_t)pix * out_ch + oc]; +} + +__global__ static void nchw_to_rows_kernel(float *rows, const float *x, + int channels, int H, int W) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int seq = H * W; + int total = seq * channels; + if (idx >= total) return; + + int s = idx / channels; + int c = idx % channels; + rows[idx] = x[(size_t)c * seq + s]; +} + +__global__ static void rows_to_nchw_kernel(float *x, const float *rows, + int channels, int H, int W) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int seq = H * W; + int total = seq * channels; + if (idx >= total) return; + + int s = idx / channels; + int c = idx % channels; + x[(size_t)c * seq + s] = rows[idx]; +} + +__global__ static void group_norm_nchw_kernel(float *out, const float *x, + const float *gamma, const float *beta, + int batch, int channels, int spatial, + int num_groups, int channels_per_group, + float eps) { + int bg = blockIdx.x; + int b = bg / num_groups; + int g = bg % num_groups; + if (b >= batch) return; + + int c_start = g * channels_per_group; + int total = channels_per_group * spatial; + int tid = threadIdx.x; + + extern __shared__ float sh[]; + float *sh_sum = sh; + float *sh_sumsq = sh + blockDim.x; + + float local_sum = 0.0f; + float local_sumsq = 0.0f; + + for (int idx = tid; idx < total; idx += blockDim.x) { + int c = c_start + idx / spatial; + int s = idx % spatial; + float v = x[((size_t)b * channels + c) * spatial + s]; + local_sum += v; + local_sumsq += v * v; + } + + sh_sum[tid] = local_sum; + sh_sumsq[tid] = local_sumsq; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + sh_sum[tid] += sh_sum[tid + stride]; + sh_sumsq[tid] += sh_sumsq[tid + stride]; + } + __syncthreads(); + } + + float mean = sh_sum[0] / (float)total; + float var = sh_sumsq[0] / (float)total - mean * mean; + float inv_std = rsqrtf(var + eps); + + for (int idx = tid; idx < total; idx += blockDim.x) { + int c = c_start + idx / spatial; + int s = idx % spatial; + size_t off = ((size_t)b * channels + c) * spatial + s; + float v = (x[off] - mean) * inv_std; + out[off] = gamma[c] * v + beta[c]; + } +} + +__global__ static void silu_kernel(float *x, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) return; + float v = x[idx]; + x[idx] = v / (1.0f + expf(-v)); +} + +__global__ static void add_inplace_kernel(float *a, const float *b, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) return; + a[idx] += b[idx]; +} + +__global__ static void upsample_nearest2x_nchw_kernel(float *out, const float *in, + int channels, int H, int W) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int outH = H * 2; + int outW = W * 2; + int total = channels * outH * outW; + if (idx >= total) return; + + int c = idx / (outH * outW); + int rem = idx % (outH * outW); + int oh = rem / outW; + int ow = rem % outW; + + int ih = oh >> 1; + int iw = ow >> 1; + + out[(size_t)c * outH * outW + (size_t)oh * outW + ow] = + in[(size_t)c * H * W + (size_t)ih * W + iw]; +} + +/* ------------------------------------------------------------------------- + * GEMM helpers (row-major wrappers) + * ------------------------------------------------------------------------- */ + +/* Row-major scores = scale * q @ k^T + * q: [seq_q, head_dim], k: [seq_k, head_dim], scores: [seq_q, seq_k] + */ +static int flux_cuda_qk_matmul_f32(float *d_scores, + const float *d_q, + const float *d_k, + int seq_q, int seq_k, int head_dim, + float scale) { + const float beta = 0.0f; + cublasStatus_t st = cublasSgemm(g_cublas, + CUBLAS_OP_T, CUBLAS_OP_N, + seq_k, seq_q, head_dim, + &scale, + d_k, head_dim, + d_q, head_dim, + &beta, + d_scores, seq_k); + return st == CUBLAS_STATUS_SUCCESS; +} + +/* Row-major out = scores @ v + * scores: [seq_q, seq_k], v: [seq_k, head_dim], out: [seq_q, head_dim] + */ +static int flux_cuda_pv_matmul_f32(float *d_out, + const float *d_scores, + const float *d_v, + int seq_q, int seq_k, int head_dim) { + const float alpha = 1.0f; + const float beta = 0.0f; + cublasStatus_t st = cublasSgemm(g_cublas, + CUBLAS_OP_N, CUBLAS_OP_N, + head_dim, seq_q, seq_k, + &alpha, + d_v, head_dim, + d_scores, seq_k, + &beta, + d_out, head_dim); + return st == CUBLAS_STATUS_SUCCESS; +} + +/* Row-major batched scores = scale * q @ k^T + * q/k/scores are head-major: [heads, seq, dim] + */ +static int flux_cuda_qk_matmul_f32_batched(float *d_scores, + const float *d_q, + const float *d_k, + int heads, int seq_q, int seq_k, int head_dim, + float scale) { + const float beta = 0.0f; + long long stride_q = (long long)seq_q * head_dim; + long long stride_k = (long long)seq_k * head_dim; + long long stride_scores = (long long)seq_q * seq_k; + + cublasStatus_t st = cublasSgemmStridedBatched(g_cublas, + CUBLAS_OP_T, CUBLAS_OP_N, + seq_k, seq_q, head_dim, + &scale, + d_k, head_dim, stride_k, + d_q, head_dim, stride_q, + &beta, + d_scores, seq_k, stride_scores, + heads); + return st == CUBLAS_STATUS_SUCCESS; +} + +/* Row-major batched out = scores @ v */ +static int flux_cuda_pv_matmul_f32_batched(float *d_out, + const float *d_scores, + const float *d_v, + int heads, int seq_q, int seq_k, int head_dim) { + const float alpha = 1.0f; + const float beta = 0.0f; + long long stride_v = (long long)seq_k * head_dim; + long long stride_scores = (long long)seq_q * seq_k; + long long stride_out = (long long)seq_q * head_dim; + + cublasStatus_t st = cublasSgemmStridedBatched(g_cublas, + CUBLAS_OP_N, CUBLAS_OP_N, + head_dim, seq_q, seq_k, + &alpha, + d_v, head_dim, stride_v, + d_scores, seq_k, stride_scores, + &beta, + d_out, head_dim, stride_out, + heads); + return st == CUBLAS_STATUS_SUCCESS; +} + +/* Row-major batched scores = scale * q @ k^T + * q/k are sequence-major [seq, heads, dim] interleaved by head. + * scores are contiguous head-major [heads, seq_q, seq_k]. + */ +static int flux_cuda_qk_matmul_f32_batched_shd(float *d_scores, + const float *d_q_shd, + const float *d_k_shd, + int heads, int seq_q, int seq_k, + int hidden, int head_dim, + float scale) { + const float beta = 0.0f; + long long stride_qk = (long long)head_dim; /* interleaved heads */ + long long stride_scores = (long long)seq_q * seq_k; /* contiguous per head */ + + cublasStatus_t st = cublasSgemmStridedBatched(g_cublas, + CUBLAS_OP_T, CUBLAS_OP_N, + seq_k, seq_q, head_dim, + &scale, + d_k_shd, hidden, stride_qk, + d_q_shd, hidden, stride_qk, + &beta, + d_scores, seq_k, stride_scores, + heads); + return st == CUBLAS_STATUS_SUCCESS; +} + +/* Row-major batched out = scores @ v + * scores are contiguous head-major [heads, seq_q, seq_k]. + * v/out are sequence-major [seq, heads, dim] interleaved by head. + */ +static int flux_cuda_pv_matmul_f32_batched_shd(float *d_out_shd, + const float *d_scores, + const float *d_v_shd, + int heads, int seq_q, int seq_k, + int hidden, int head_dim) { + const float alpha = 1.0f; + const float beta = 0.0f; + long long stride_v_out = (long long)head_dim; /* interleaved heads */ + long long stride_scores = (long long)seq_q * seq_k; /* contiguous per head */ + + cublasStatus_t st = cublasSgemmStridedBatched(g_cublas, + CUBLAS_OP_N, CUBLAS_OP_N, + head_dim, seq_q, seq_k, + &alpha, + d_v_shd, hidden, stride_v_out, + d_scores, seq_k, stride_scores, + &beta, + d_out_shd, hidden, stride_v_out, + heads); + return st == CUBLAS_STATUS_SUCCESS; +} + +/* BF16 tensor-core QK path via cuBLASLt (single-head). */ +static int flux_cuda_qk_matmul_bf16_lt(float *d_scores, + const float *d_q, + const float *d_k, + __nv_bfloat16 *d_q_bf16, + __nv_bfloat16 *d_k_bf16, + int seq_q, int seq_k, int head_dim, + float scale) { + int q_elems = seq_q * head_dim; + int k_elems = seq_k * head_dim; + + int threads = 256; + int q_blocks = (q_elems + threads - 1) / threads; + int k_blocks = (k_elems + threads - 1) / threads; + f32_to_bf16_kernel<<>>(d_q, d_q_bf16, q_elems); + f32_to_bf16_kernel<<>>(d_k, d_k_bf16, k_elems); + if (cudaGetLastError() != cudaSuccess) { + return 0; + } + + cublasLtMatmulDesc_t op_desc = NULL; + cublasLtMatrixLayout_t a_layout = NULL; + cublasLtMatrixLayout_t b_layout = NULL; + cublasLtMatrixLayout_t c_layout = NULL; + cublasLtMatmulPreference_t pref = NULL; + cublasLtMatmulHeuristicResult_t heuristic; + int num_results = 0; + + cublasOperation_t trans_a = CUBLAS_OP_N; + cublasOperation_t trans_b = CUBLAS_OP_T; + cublasLtOrder_t order = CUBLASLT_ORDER_ROW; + +#ifdef CUBLAS_COMPUTE_32F_FAST_16BF + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; +#else + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#endif + + cublasStatus_t st = cublasLtMatmulDescCreate(&op_desc, compute_type, CUDA_R_32F); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + + st = cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA, + &trans_a, sizeof(trans_a)); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + st = cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB, + &trans_b, sizeof(trans_b)); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + + st = cublasLtMatrixLayoutCreate(&a_layout, CUDA_R_16BF, seq_q, head_dim, head_dim); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + st = cublasLtMatrixLayoutCreate(&b_layout, CUDA_R_16BF, seq_k, head_dim, head_dim); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + st = cublasLtMatrixLayoutCreate(&c_layout, CUDA_R_32F, seq_q, seq_k, seq_k); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + + st = cublasLtMatrixLayoutSetAttribute(a_layout, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order, sizeof(order)); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + st = cublasLtMatrixLayoutSetAttribute(b_layout, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order, sizeof(order)); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + st = cublasLtMatrixLayoutSetAttribute(c_layout, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order, sizeof(order)); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + + st = cublasLtMatmulPreferenceCreate(&pref); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + + st = cublasLtMatmulPreferenceSetAttribute(pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &g_d_workspace_bytes, sizeof(g_d_workspace_bytes)); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + + st = cublasLtMatmulAlgoGetHeuristic(g_cublas_lt, op_desc, + a_layout, b_layout, c_layout, c_layout, + pref, 1, &heuristic, &num_results); + if (st != CUBLAS_STATUS_SUCCESS || num_results == 0) goto fail; + + { + float beta = 0.0f; + st = cublasLtMatmul(g_cublas_lt, op_desc, + &scale, + d_q_bf16, a_layout, + d_k_bf16, b_layout, + &beta, + d_scores, c_layout, + d_scores, c_layout, + &heuristic.algo, + g_d_workspace, g_d_workspace_bytes, + g_stream); + if (st != CUBLAS_STATUS_SUCCESS) goto fail; + } + + cublasLtMatmulPreferenceDestroy(pref); + cublasLtMatrixLayoutDestroy(c_layout); + cublasLtMatrixLayoutDestroy(b_layout); + cublasLtMatrixLayoutDestroy(a_layout); + cublasLtMatmulDescDestroy(op_desc); + return 1; + +fail: + if (pref) cublasLtMatmulPreferenceDestroy(pref); + if (c_layout) cublasLtMatrixLayoutDestroy(c_layout); + if (b_layout) cublasLtMatrixLayoutDestroy(b_layout); + if (a_layout) cublasLtMatrixLayoutDestroy(a_layout); + if (op_desc) cublasLtMatmulDescDestroy(op_desc); + return 0; +} + +/* ------------------------------------------------------------------------- + * Public API + * ------------------------------------------------------------------------- */ + +int flux_cuda_attention_single(float *out, + const float *q, const float *k, const float *v, + int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask, + int prefer_bf16) { + if (!out || !q || !k || !v) return 0; + if (seq_q <= 0 || seq_k <= 0 || head_dim <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + size_t q_bytes = (size_t)seq_q * head_dim * sizeof(float); + size_t k_bytes = (size_t)seq_k * head_dim * sizeof(float); + size_t v_bytes = (size_t)seq_k * head_dim * sizeof(float); + size_t scores_bytes = (size_t)seq_q * seq_k * sizeof(float); + size_t out_bytes = (size_t)seq_q * head_dim * sizeof(float); + size_t q_bf16_bytes = (size_t)seq_q * head_dim * sizeof(__nv_bfloat16); + size_t k_bf16_bytes = (size_t)seq_k * head_dim * sizeof(__nv_bfloat16); + size_t mask_bytes = (size_t)seq_k * sizeof(int); + size_t workspace_target = (size_t)8 * 1024 * 1024; + + if (!flux_cuda_ensure_buffer((void **)&g_d_q, &g_d_q_bytes, q_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_k, &g_d_k_bytes, k_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_v, &g_d_v_bytes, v_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_scores, &g_d_scores_bytes, scores_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_out, &g_d_out_bytes, out_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_q_bf16, &g_d_q_bf16_bytes, q_bf16_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_k_bf16, &g_d_k_bf16_bytes, k_bf16_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_workspace, &g_d_workspace_bytes, workspace_target)) { + return 0; + } + + if (attention_mask) { + if (!flux_cuda_ensure_buffer((void **)&g_d_mask, &g_d_mask_bytes, mask_bytes)) { + return 0; + } + if (cudaMemcpy(g_d_mask, attention_mask, mask_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + return 0; + } + } + + if (cudaMemcpy(g_d_q, q, q_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(g_d_k, k, k_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(g_d_v, v, v_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + + int ok = 0; + if (prefer_bf16) { + ok = flux_cuda_qk_matmul_bf16_lt(g_d_scores, g_d_q, g_d_k, + g_d_q_bf16, g_d_k_bf16, + seq_q, seq_k, head_dim, scale); + } + if (!ok) { + ok = flux_cuda_qk_matmul_f32(g_d_scores, g_d_q, g_d_k, seq_q, seq_k, head_dim, scale); + } + if (!ok) return 0; + + { + int threads = 256; + int rows = seq_q; + size_t shmem = (size_t)threads * sizeof(float); + masked_softmax_kernel<<>>( + g_d_scores, g_d_mask, rows, seq_k, seq_q, + causal ? 1 : 0, attention_mask ? 1 : 0 + ); + if (cudaGetLastError() != cudaSuccess) return 0; + } + + if (!flux_cuda_pv_matmul_f32(g_d_out, g_d_scores, g_d_v, seq_q, seq_k, head_dim)) { + return 0; + } + + if (cudaMemcpy(out, g_d_out, out_bytes, cudaMemcpyDeviceToHost) != cudaSuccess) { + return 0; + } + + return 1; +} + +int flux_cuda_attention_batched(float *out, + const float *q, const float *k, const float *v, + int heads, int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask) { + if (!out || !q || !k || !v) return 0; + if (heads <= 0 || seq_q <= 0 || seq_k <= 0 || head_dim <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + size_t q_bytes = (size_t)heads * seq_q * head_dim * sizeof(float); + size_t k_bytes = (size_t)heads * seq_k * head_dim * sizeof(float); + size_t v_bytes = (size_t)heads * seq_k * head_dim * sizeof(float); + size_t scores_bytes = (size_t)heads * seq_q * seq_k * sizeof(float); + size_t out_bytes = (size_t)heads * seq_q * head_dim * sizeof(float); + size_t mask_bytes = (size_t)seq_k * sizeof(int); + + if (!flux_cuda_ensure_buffer((void **)&g_d_q, &g_d_q_bytes, q_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_k, &g_d_k_bytes, k_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_v, &g_d_v_bytes, v_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_scores, &g_d_scores_bytes, scores_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_out, &g_d_out_bytes, out_bytes)) { + return 0; + } + + if (attention_mask) { + if (!flux_cuda_ensure_buffer((void **)&g_d_mask, &g_d_mask_bytes, mask_bytes)) { + return 0; + } + if (cudaMemcpy(g_d_mask, attention_mask, mask_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + return 0; + } + } + + if (cudaMemcpy(g_d_q, q, q_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(g_d_k, k, k_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(g_d_v, v, v_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + + if (!flux_cuda_qk_matmul_f32_batched(g_d_scores, g_d_q, g_d_k, + heads, seq_q, seq_k, head_dim, scale)) { + return 0; + } + + { + int threads = 256; + int rows = heads * seq_q; + size_t shmem = (size_t)threads * sizeof(float); + masked_softmax_kernel<<>>( + g_d_scores, g_d_mask, rows, seq_k, seq_q, + causal ? 1 : 0, attention_mask ? 1 : 0 + ); + if (cudaGetLastError() != cudaSuccess) return 0; + } + + if (!flux_cuda_pv_matmul_f32_batched(g_d_out, g_d_scores, g_d_v, + heads, seq_q, seq_k, head_dim)) { + return 0; + } + + if (cudaMemcpy(out, g_d_out, out_bytes, cudaMemcpyDeviceToHost) != cudaSuccess) { + return 0; + } + + return 1; +} + +int flux_cuda_attention_batched_shd(float *out, + const float *q, const float *k, const float *v, + int heads, int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask) { + if (!out || !q || !k || !v) return 0; + if (heads <= 0 || seq_q <= 0 || seq_k <= 0 || head_dim <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + size_t q_shd_bytes = (size_t)seq_q * heads * head_dim * sizeof(float); + size_t k_shd_bytes = (size_t)seq_k * heads * head_dim * sizeof(float); + size_t v_shd_bytes = (size_t)seq_k * heads * head_dim * sizeof(float); + size_t out_shd_bytes = (size_t)seq_q * heads * head_dim * sizeof(float); + size_t scores_bytes = (size_t)heads * seq_q * seq_k * sizeof(float); + size_t mask_bytes = (size_t)seq_k * sizeof(int); + int hidden = heads * head_dim; + + if (!flux_cuda_ensure_buffer((void **)&g_d_q_shd, &g_d_q_shd_bytes, q_shd_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_k_shd, &g_d_k_shd_bytes, k_shd_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_v_shd, &g_d_v_shd_bytes, v_shd_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_out_shd, &g_d_out_shd_bytes, out_shd_bytes) || + !flux_cuda_ensure_buffer((void **)&g_d_scores, &g_d_scores_bytes, scores_bytes)) { + return 0; + } + + if (attention_mask) { + if (!flux_cuda_ensure_buffer((void **)&g_d_mask, &g_d_mask_bytes, mask_bytes)) { + return 0; + } + if (cudaMemcpy(g_d_mask, attention_mask, mask_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + return 0; + } + } + + if (cudaMemcpy(g_d_q_shd, q, q_shd_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(g_d_k_shd, k, k_shd_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(g_d_v_shd, v, v_shd_bytes, cudaMemcpyHostToDevice) != cudaSuccess) return 0; + + if (!flux_cuda_qk_matmul_f32_batched_shd(g_d_scores, g_d_q_shd, g_d_k_shd, + heads, seq_q, seq_k, hidden, head_dim, scale)) { + return 0; + } + + { + int threads = 256; + int rows = heads * seq_q; + size_t shmem = (size_t)threads * sizeof(float); + masked_softmax_kernel<<>>( + g_d_scores, g_d_mask, rows, seq_k, seq_q, + causal ? 1 : 0, attention_mask ? 1 : 0 + ); + if (cudaGetLastError() != cudaSuccess) return 0; + } + + if (!flux_cuda_pv_matmul_f32_batched_shd(g_d_out_shd, g_d_scores, g_d_v_shd, + heads, seq_q, seq_k, hidden, head_dim)) { + return 0; + } + + if (cudaMemcpy(out, g_d_out_shd, out_shd_bytes, cudaMemcpyDeviceToHost) != cudaSuccess) { + return 0; + } + + return 1; +} + +int flux_cuda_attention_batched_shd_device(float *d_out, + const float *d_q, const float *d_k, const float *d_v, + int heads, int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask) { + if (!d_out || !d_q || !d_k || !d_v) return 0; + if (heads <= 0 || seq_q <= 0 || seq_k <= 0 || head_dim <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + size_t scores_bytes = (size_t)heads * seq_q * seq_k * sizeof(float); + size_t mask_bytes = (size_t)seq_k * sizeof(int); + int hidden = heads * head_dim; + + if (!causal && !attention_mask && flux_cuda_flash_enabled() && head_dim > 0 && head_dim <= 256) { + int threads = 1; + while (threads < head_dim) threads <<= 1; + if (threads < 32) threads = 32; + if (threads > 256) threads = 256; + int blocks = heads * seq_q; + size_t shmem = (size_t)threads * sizeof(float); + flash_attn_shd_kernel<<>>( + d_out, d_q, d_k, d_v, heads, seq_q, seq_k, head_dim, scale + ); + return cudaGetLastError() == cudaSuccess; + } + + if (!flux_cuda_ensure_buffer((void **)&g_d_scores, &g_d_scores_bytes, scores_bytes)) { + return 0; + } + + if (attention_mask) { + if (!flux_cuda_ensure_buffer((void **)&g_d_mask, &g_d_mask_bytes, mask_bytes)) { + return 0; + } + if (cudaMemcpy(g_d_mask, attention_mask, mask_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + return 0; + } + } + + if (!flux_cuda_qk_matmul_f32_batched_shd(g_d_scores, d_q, d_k, + heads, seq_q, seq_k, hidden, head_dim, scale)) { + return 0; + } + + { + int threads = 256; + int rows = heads * seq_q; + size_t shmem = (size_t)threads * sizeof(float); + masked_softmax_kernel<<>>( + g_d_scores, g_d_mask, rows, seq_k, seq_q, + causal ? 1 : 0, attention_mask ? 1 : 0 + ); + if (cudaGetLastError() != cudaSuccess) return 0; + } + + if (!flux_cuda_pv_matmul_f32_batched_shd(d_out, g_d_scores, d_v, + heads, seq_q, seq_k, hidden, head_dim)) { + return 0; + } + + return 1; +} + +int flux_cuda_adaln_norm_device(float *d_out, const float *d_x, + const float *d_shift, const float *d_scale, + int seq, int hidden, float eps) { + if (!d_out || !d_x || !d_shift || !d_scale || seq <= 0 || hidden <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int threads = 256; + size_t shmem = (size_t)threads * 2 * sizeof(float); + adaln_norm_kernel<<>>(d_out, d_x, d_shift, d_scale, seq, hidden, eps); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_split_qkv_mlp_device(const float *d_fused, + float *d_q, float *d_k, float *d_v, + float *d_gate, float *d_up, + int seq, int hidden, int mlp_hidden) { + if (!d_fused || !d_q || !d_k || !d_v || !d_gate || !d_up) return 0; + if (seq <= 0 || hidden <= 0 || mlp_hidden <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int threads = 256; + split_qkv_mlp_kernel<<>>(d_fused, d_q, d_k, d_v, d_gate, d_up, + seq, hidden, mlp_hidden); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_qk_rms_norm_device(float *d_q, float *d_k, + const float *d_q_weight, const float *d_k_weight, + int seq, int heads, int head_dim, float eps) { + if (!d_q || !d_k || !d_q_weight || !d_k_weight) return 0; + if (seq <= 0 || heads <= 0 || head_dim <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int rows = seq * heads; + int threads = 128; + size_t shmem = (size_t)threads * 2 * sizeof(float); + qk_rms_norm_kernel<<>>(d_q, d_k, d_q_weight, d_k_weight, + rows, head_dim, eps); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_rope_unified_device(float *d_q, float *d_k, + const float *d_txt_cos, const float *d_txt_sin, + const float *d_img_cos, const float *d_img_sin, + int seq, int img_offset, int heads, int head_dim) { + if (!d_q || !d_k || !d_txt_cos || !d_txt_sin || !d_img_cos || !d_img_sin) return 0; + if (seq <= 0 || img_offset < 0 || img_offset > seq || heads <= 0 || head_dim <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int pairs = head_dim / 2; + int total = seq * heads * pairs; + int threads = 256; + int blocks = (total + threads - 1) / threads; + rope_unified_kernel<<>>(d_q, d_k, + d_txt_cos, d_txt_sin, + d_img_cos, d_img_sin, + seq, img_offset, heads, head_dim); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_silu_mul_device(float *d_gate, const float *d_up, int n) { + if (!d_gate || !d_up || n <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int threads = 256; + int blocks = (n + threads - 1) / threads; + silu_mul_kernel<<>>(d_gate, d_up, n); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_concat_attn_mlp_device(const float *d_attn, const float *d_mlp, + float *d_out, int seq, int hidden, int mlp_hidden) { + if (!d_attn || !d_mlp || !d_out) return 0; + if (seq <= 0 || hidden <= 0 || mlp_hidden <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int total = seq * (hidden + mlp_hidden); + int threads = 256; + int blocks = (total + threads - 1) / threads; + concat_attn_mlp_kernel<<>>(d_attn, d_mlp, d_out, seq, hidden, mlp_hidden); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_gated_add_device(float *d_hidden, const float *d_gate, + const float *d_proj, int seq, int hidden) { + if (!d_hidden || !d_gate || !d_proj) return 0; + if (seq <= 0 || hidden <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int total = seq * hidden; + int threads = 256; + int blocks = (total + threads - 1) / threads; + gated_add_kernel<<>>(d_hidden, d_gate, d_proj, seq, hidden); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_concat_seq_device(float *d_out, const float *d_a, const float *d_b, + int seq_a, int seq_b, int hidden) { + if (!d_out || !d_a || !d_b) return 0; + if (seq_a < 0 || seq_b < 0 || hidden <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int total = (seq_a + seq_b) * hidden; + int threads = 256; + int blocks = (total + threads - 1) / threads; + concat_seq_kernel<<>>(d_out, d_a, d_b, seq_a, seq_b, hidden); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_im2col_nchw_rows_device(float *d_col, const float *d_in, + int in_ch, int H, int W, + int kH, int kW, int stride, int padding, + int outH, int outW, + int row_offset, int tile_h) { + if (!d_col || !d_in) return 0; + if (in_ch <= 0 || H <= 0 || W <= 0 || kH <= 0 || kW <= 0 || stride <= 0) return 0; + if (outH <= 0 || outW <= 0 || row_offset < 0 || tile_h <= 0 || row_offset + tile_h > outH) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int K = in_ch * kH * kW; + int tile_pixels = tile_h * outW; + int total = tile_pixels * K; + int threads = 256; + int blocks = (total + threads - 1) / threads; + im2col_nchw_rows_kernel<<>>(d_col, d_in, + in_ch, H, W, + kH, kW, stride, padding, + outW, row_offset, tile_h, K); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_add_bias_rows_device(float *d_rows, const float *d_bias, + int rows, int cols) { + if (!d_rows || !d_bias) return 0; + if (rows <= 0 || cols <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int total = rows * cols; + int threads = 256; + int blocks = (total + threads - 1) / threads; + add_bias_rows_kernel<<>>(d_rows, d_bias, rows, cols); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_rows_to_nchw_tile_device(float *d_out, const float *d_rows, + int out_ch, int outH, int outW, + int row_offset, int tile_h) { + if (!d_out || !d_rows) return 0; + if (out_ch <= 0 || outH <= 0 || outW <= 0 || row_offset < 0 || tile_h <= 0) return 0; + if (row_offset + tile_h > outH) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int tile_pixels = tile_h * outW; + int total = out_ch * tile_pixels; + int threads = 256; + int blocks = (total + threads - 1) / threads; + rows_to_nchw_tile_kernel<<>>(d_out, d_rows, out_ch, outH, outW, + row_offset, tile_h); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_nchw_to_rows_device(float *d_rows, const float *d_x, + int channels, int H, int W) { + if (!d_rows || !d_x) return 0; + if (channels <= 0 || H <= 0 || W <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int total = channels * H * W; + int threads = 256; + int blocks = (total + threads - 1) / threads; + nchw_to_rows_kernel<<>>(d_rows, d_x, channels, H, W); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_rows_to_nchw_device(float *d_x, const float *d_rows, + int channels, int H, int W) { + if (!d_x || !d_rows) return 0; + if (channels <= 0 || H <= 0 || W <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int total = channels * H * W; + int threads = 256; + int blocks = (total + threads - 1) / threads; + rows_to_nchw_kernel<<>>(d_x, d_rows, channels, H, W); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_group_norm_nchw_device(float *d_out, const float *d_x, + const float *d_gamma, const float *d_beta, + int batch, int channels, int H, int W, + int num_groups, float eps) { + if (!d_out || !d_x || !d_gamma || !d_beta) return 0; + if (batch <= 0 || channels <= 0 || H <= 0 || W <= 0 || num_groups <= 0) return 0; + if (channels % num_groups != 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int channels_per_group = channels / num_groups; + int spatial = H * W; + int threads = 256; + int blocks = batch * num_groups; + size_t shmem = (size_t)threads * 2 * sizeof(float); + group_norm_nchw_kernel<<>>(d_out, d_x, d_gamma, d_beta, + batch, channels, spatial, + num_groups, channels_per_group, + eps); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_silu_device(float *d_x, int n) { + if (!d_x || n <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int threads = 256; + int blocks = (n + threads - 1) / threads; + silu_kernel<<>>(d_x, n); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_add_inplace_device(float *d_a, const float *d_b, int n) { + if (!d_a || !d_b || n <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int threads = 256; + int blocks = (n + threads - 1) / threads; + add_inplace_kernel<<>>(d_a, d_b, n); + return cudaGetLastError() == cudaSuccess; +} + +int flux_cuda_upsample_nearest2x_nchw_device(float *d_out, const float *d_in, + int channels, int H, int W) { + if (!d_out || !d_in) return 0; + if (channels <= 0 || H <= 0 || W <= 0) return 0; + if (!flux_cuda_ensure_init()) return 0; + + int outH = H * 2; + int outW = W * 2; + int total = channels * outH * outW; + int threads = 256; + int blocks = (total + threads - 1) / threads; + upsample_nearest2x_nchw_kernel<<>>(d_out, d_in, channels, H, W); + return cudaGetLastError() == cudaSuccess; +} diff --git a/flux_cuda.h b/flux_cuda.h new file mode 100644 index 0000000..65ff508 --- /dev/null +++ b/flux_cuda.h @@ -0,0 +1,128 @@ +/* + * FLUX CUDA Attention Helpers + * + * CUDA-only helpers for attention operations that keep intermediate tensors + * on GPU and run softmax/masking directly on device. + */ + +#ifndef FLUX_CUDA_H +#define FLUX_CUDA_H + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Single-head attention: + * out[seq_q, head_dim] = softmax((q @ k^T) * scale + masks) @ v + * + * q: [seq_q, head_dim] + * k: [seq_k, head_dim] + * v: [seq_k, head_dim] + * + * causal: apply causal mask (j > i masked) + * attention_mask: optional [seq_k] mask (0 = masked, non-zero = valid), can be NULL + * prefer_bf16: try BF16 tensor-core QK matmul via cuBLASLt first + * + * Returns 1 on success, 0 on failure (caller should fall back). + */ +int flux_cuda_attention_single(float *out, + const float *q, const float *k, const float *v, + int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask, + int prefer_bf16); + +/* + * Batched attention for equal-head layouts: + * q/k/v/out are [heads, seq, head_dim] contiguous (head-major). + * + * causal and attention_mask semantics are the same as single-head. + * Returns 1 on success, 0 on failure. + */ +int flux_cuda_attention_batched(float *out, + const float *q, const float *k, const float *v, + int heads, int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask); + +/* + * Batched attention with sequence-major input/output layout: + * q/k/v/out are [seq, heads * head_dim] contiguous (sequence-major). + * + * Internally transposes on GPU, runs the same CUDA batched attention path, + * then transposes back. This avoids CPU transpose overhead in single-block + * transformer paths that already operate in [seq, hidden] layout. + * + * Returns 1 on success, 0 on failure. + */ +int flux_cuda_attention_batched_shd(float *out, + const float *q, const float *k, const float *v, + int heads, int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask); + +/* Same as flux_cuda_attention_batched_shd(), but operates directly on device + * pointers and leaves output on device (no host/device copies). */ +int flux_cuda_attention_batched_shd_device(float *d_out, + const float *d_q, const float *d_k, const float *d_v, + int heads, int seq_q, int seq_k, int head_dim, + float scale, int causal, + const int *attention_mask); + +/* Set CUDA stream used by CUDA attention/op helpers. + * Pass NULL to use the default stream. */ +int flux_cuda_ops_set_stream(void *stream_handle); + +/* CUDA device kernels used by CUDA-resident transformer blocks. */ +int flux_cuda_adaln_norm_device(float *d_out, const float *d_x, + const float *d_shift, const float *d_scale, + int seq, int hidden, float eps); +int flux_cuda_split_qkv_mlp_device(const float *d_fused, + float *d_q, float *d_k, float *d_v, + float *d_gate, float *d_up, + int seq, int hidden, int mlp_hidden); +int flux_cuda_qk_rms_norm_device(float *d_q, float *d_k, + const float *d_q_weight, const float *d_k_weight, + int seq, int heads, int head_dim, float eps); +int flux_cuda_rope_unified_device(float *d_q, float *d_k, + const float *d_txt_cos, const float *d_txt_sin, + const float *d_img_cos, const float *d_img_sin, + int seq, int img_offset, int heads, int head_dim); +int flux_cuda_silu_mul_device(float *d_gate, const float *d_up, int n); +int flux_cuda_concat_attn_mlp_device(const float *d_attn, const float *d_mlp, + float *d_out, int seq, int hidden, int mlp_hidden); +int flux_cuda_gated_add_device(float *d_hidden, const float *d_gate, + const float *d_proj, int seq, int hidden); +int flux_cuda_concat_seq_device(float *d_out, const float *d_a, const float *d_b, + int seq_a, int seq_b, int hidden); + +/* Generic CUDA tensor ops used by CUDA-resident VAE decode path. */ +int flux_cuda_im2col_nchw_rows_device(float *d_col, const float *d_in, + int in_ch, int H, int W, + int kH, int kW, int stride, int padding, + int outH, int outW, + int row_offset, int tile_h); +int flux_cuda_add_bias_rows_device(float *d_rows, const float *d_bias, + int rows, int cols); +int flux_cuda_rows_to_nchw_tile_device(float *d_out, const float *d_rows, + int out_ch, int outH, int outW, + int row_offset, int tile_h); +int flux_cuda_nchw_to_rows_device(float *d_rows, const float *d_x, + int channels, int H, int W); +int flux_cuda_rows_to_nchw_device(float *d_x, const float *d_rows, + int channels, int H, int W); +int flux_cuda_group_norm_nchw_device(float *d_out, const float *d_x, + const float *d_gamma, const float *d_beta, + int batch, int channels, int H, int W, + int num_groups, float eps); +int flux_cuda_silu_device(float *d_x, int n); +int flux_cuda_add_inplace_device(float *d_a, const float *d_b, int n); +int flux_cuda_upsample_nearest2x_nchw_device(float *d_out, const float *d_in, + int channels, int H, int W); + +#ifdef __cplusplus +} +#endif + +#endif /* FLUX_CUDA_H */ diff --git a/flux_kernels.c b/flux_kernels.c index c56ab7f..7ec92e1 100644 --- a/flux_kernels.c +++ b/flux_kernels.c @@ -10,6 +10,7 @@ #include #include #include +#include /* Use Metal for GPU acceleration on Apple Silicon */ #ifdef USE_METAL @@ -25,10 +26,796 @@ #endif #endif -/* Minimum matrix size to use GPU (smaller matrices are faster on CPU) */ +#ifdef USE_CUDA +#include +#include +#endif + +/* Minimum matrix size to use Metal GPU (smaller matrices are faster on CPU) */ #define MIN_GPU_ELEMENTS (512 * 512) -/* fast_expf is defined in flux_kernels.h */ +#ifdef USE_CUDA +/* Default minimum GEMM op-count (M*K*N) to dispatch to CUDA. + * Tunable with FLUX_CUDA_MIN_OPS env var. + * Lowered after CUDA-attention batching benchmarks (2M was consistently faster + * than 8M-20M on FLUX.2-klein 256x256 runs). */ +#define MIN_CUDA_OPS_DEFAULT ((size_t)2 * 1024 * 1024) +#define CUDA_WEIGHT_CACHE_DEFAULT_MB 1024ULL +#define CUDA_WEIGHT_CACHE_MAX_ENTRIES 256 +#define CUDA_WEIGHT_CACHE_MIN_BYTES ((size_t)1 * 1024 * 1024) + +static cublasHandle_t g_cuda_handle = NULL; +static cudaStream_t g_cuda_stream = NULL; +static float *g_cuda_a = NULL; +static float *g_cuda_b = NULL; +static float *g_cuda_c = NULL; +static uint16_t *g_cuda_b_bf16 = NULL; +static size_t g_cuda_a_bytes = 0; +static size_t g_cuda_b_bytes = 0; +static size_t g_cuda_c_bytes = 0; +static size_t g_cuda_b_bf16_bytes = 0; +static int g_cuda_available = -1; +static int g_cuda_warned = 0; +static size_t g_cuda_min_ops = 0; +static int g_cuda_bf16_linear_ok = -1; +static int g_cuda_tf32_mode = -1; +static int g_cuda_weight_cache_configured = 0; +static size_t g_cuda_weight_cache_limit_bytes = 0; +static size_t g_cuda_weight_cache_used_bytes = 0; +static uint64_t g_cuda_weight_cache_tick = 1; + +typedef struct { + const void *host_ptr; + void *device_ptr; + size_t bytes; + int k; + int n; + int transpose_b; + int data_type; /* 0=f32, 1=bf16 */ + uint64_t fingerprint; + uint64_t last_use; +} cuda_weight_cache_entry_t; + +static cuda_weight_cache_entry_t g_cuda_weight_cache[CUDA_WEIGHT_CACHE_MAX_ENTRIES]; + +static void flux_cuda_warn_once(const char *msg) { + if (!g_cuda_warned) { + fprintf(stderr, "%s\n", msg); + g_cuda_warned = 1; + } +} + +static size_t flux_cuda_get_min_ops(void) { + if (g_cuda_min_ops != 0) { + return g_cuda_min_ops; + } + + g_cuda_min_ops = MIN_CUDA_OPS_DEFAULT; + const char *env = getenv("FLUX_CUDA_MIN_OPS"); + if (!env || !*env) { + return g_cuda_min_ops; + } + + errno = 0; + char *end = NULL; + unsigned long long parsed = strtoull(env, &end, 10); + if (errno == 0 && end && *end == '\0' && parsed > 0ULL) { + g_cuda_min_ops = (size_t)parsed; + } else { + flux_cuda_warn_once("CUDA: invalid FLUX_CUDA_MIN_OPS value, using default threshold"); + } + + return g_cuda_min_ops; +} + +static int flux_cuda_tf32_enabled(void) { + if (g_cuda_tf32_mode != -1) { + return g_cuda_tf32_mode; + } + g_cuda_tf32_mode = getenv("FLUX_CUDA_NO_TF32") ? 0 : 1; + return g_cuda_tf32_mode; +} + +static uint64_t flux_cuda_weight_fingerprint(const void *data, size_t bytes) { + if (!data || bytes == 0) { + return 0; + } + + const uint8_t *p = (const uint8_t *)data; + uint64_t h = 1469598103934665603ULL ^ (uint64_t)bytes; + const int samples = 16; + + for (int i = 0; i < samples; i++) { + size_t idx = (size_t)i * (bytes - 1) / (size_t)(samples - 1); + h ^= (uint64_t)p[idx]; + h *= 1099511628211ULL; + } + return h; +} + +static uint64_t flux_cuda_next_cache_tick(void) { + g_cuda_weight_cache_tick++; + if (g_cuda_weight_cache_tick == 0) { + g_cuda_weight_cache_tick = 1; + } + return g_cuda_weight_cache_tick; +} + +static void flux_cuda_weight_cache_clear_entry(int idx) { + cuda_weight_cache_entry_t *e = &g_cuda_weight_cache[idx]; + if (e->device_ptr) { + cudaFree(e->device_ptr); + } + if (g_cuda_weight_cache_used_bytes >= e->bytes) { + g_cuda_weight_cache_used_bytes -= e->bytes; + } else { + g_cuda_weight_cache_used_bytes = 0; + } + memset(e, 0, sizeof(*e)); +} + +static void flux_cuda_weight_cache_clear_all(void) { + for (int i = 0; i < CUDA_WEIGHT_CACHE_MAX_ENTRIES; i++) { + flux_cuda_weight_cache_clear_entry(i); + } + g_cuda_weight_cache_used_bytes = 0; +} + +static int flux_cuda_weight_cache_find_lru(void) { + int lru_idx = -1; + uint64_t oldest = UINT64_MAX; + + for (int i = 0; i < CUDA_WEIGHT_CACHE_MAX_ENTRIES; i++) { + cuda_weight_cache_entry_t *e = &g_cuda_weight_cache[i]; + if (!e->device_ptr) continue; + if (e->last_use < oldest) { + oldest = e->last_use; + lru_idx = i; + } + } + + return lru_idx; +} + +static int flux_cuda_weight_cache_ensure_space(size_t need_bytes) { + if (need_bytes > g_cuda_weight_cache_limit_bytes) { + return 0; + } + + while (g_cuda_weight_cache_used_bytes + need_bytes > g_cuda_weight_cache_limit_bytes) { + int lru_idx = flux_cuda_weight_cache_find_lru(); + if (lru_idx < 0) { + return 0; + } + flux_cuda_weight_cache_clear_entry(lru_idx); + } + + return 1; +} + +static int flux_cuda_weight_cache_find(const void *host_ptr, + int k, int n, int transpose_b, + int data_type, uint64_t fingerprint) { + for (int i = 0; i < CUDA_WEIGHT_CACHE_MAX_ENTRIES; i++) { + cuda_weight_cache_entry_t *e = &g_cuda_weight_cache[i]; + if (!e->device_ptr) continue; + if (e->host_ptr == host_ptr && + e->k == k && + e->n == n && + e->transpose_b == transpose_b && + e->data_type == data_type) { + if (e->fingerprint == fingerprint) { + e->last_use = flux_cuda_next_cache_tick(); + return i; + } + /* Same address reused for different content. Evict stale entry. */ + flux_cuda_weight_cache_clear_entry(i); + return -1; + } + } + return -1; +} + +static int flux_cuda_weight_cache_select_slot(void) { + for (int i = 0; i < CUDA_WEIGHT_CACHE_MAX_ENTRIES; i++) { + if (!g_cuda_weight_cache[i].device_ptr) { + return i; + } + } + + int lru_idx = flux_cuda_weight_cache_find_lru(); + if (lru_idx >= 0) { + flux_cuda_weight_cache_clear_entry(lru_idx); + } + return lru_idx; +} + +static void *flux_cuda_weight_cache_insert(const void *host_ptr, + int k, int n, int transpose_b, int data_type, + uint64_t fingerprint, + size_t bytes) { + if (bytes > g_cuda_weight_cache_limit_bytes || g_cuda_weight_cache_limit_bytes == 0) { + return NULL; + } + + if (!flux_cuda_weight_cache_ensure_space(bytes)) { + return NULL; + } + + int slot = flux_cuda_weight_cache_select_slot(); + if (slot < 0) { + return NULL; + } + + void *device_ptr = NULL; + if (cudaMalloc(&device_ptr, bytes) != cudaSuccess) { + return NULL; + } + if (cudaMemcpy(device_ptr, host_ptr, bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + cudaFree(device_ptr); + return NULL; + } + + cuda_weight_cache_entry_t *e = &g_cuda_weight_cache[slot]; + e->host_ptr = host_ptr; + e->device_ptr = device_ptr; + e->bytes = bytes; + e->k = k; + e->n = n; + e->transpose_b = transpose_b; + e->data_type = data_type; + e->fingerprint = fingerprint; + e->last_use = flux_cuda_next_cache_tick(); + g_cuda_weight_cache_used_bytes += bytes; + return device_ptr; +} + +static void flux_cuda_weight_cache_configure(void) { + if (g_cuda_weight_cache_configured) { + return; + } + g_cuda_weight_cache_configured = 1; + + size_t limit_bytes = (size_t)(CUDA_WEIGHT_CACHE_DEFAULT_MB * 1024ULL * 1024ULL); + const char *env = getenv("FLUX_CUDA_WEIGHT_CACHE_MB"); + + if (env && *env) { + errno = 0; + char *end = NULL; + unsigned long long parsed = strtoull(env, &end, 10); + if (errno == 0 && end && *end == '\0') { + if (parsed == 0ULL) { + g_cuda_weight_cache_limit_bytes = 0; + return; + } + if (parsed > (unsigned long long)(SIZE_MAX / (1024ULL * 1024ULL))) { + limit_bytes = SIZE_MAX; + } else { + limit_bytes = (size_t)parsed * 1024ULL * 1024ULL; + } + } else { + flux_cuda_warn_once("CUDA: invalid FLUX_CUDA_WEIGHT_CACHE_MB value, using default cache size"); + } + } else { + size_t free_mem = 0, total_mem = 0; + if (cudaMemGetInfo(&free_mem, &total_mem) == cudaSuccess && free_mem > 0) { + size_t auto_limit = free_mem / 4; /* Keep cache under 25% of currently free VRAM. */ + size_t max_default = (size_t)(CUDA_WEIGHT_CACHE_DEFAULT_MB * 1024ULL * 1024ULL); + if (auto_limit < CUDA_WEIGHT_CACHE_MIN_BYTES) auto_limit = CUDA_WEIGHT_CACHE_MIN_BYTES; + if (auto_limit < max_default) limit_bytes = auto_limit; + } + } + + g_cuda_weight_cache_limit_bytes = limit_bytes; +} + +static void flux_cuda_cleanup_impl(void) { + flux_cuda_weight_cache_clear_all(); + + if (g_cuda_a) cudaFree(g_cuda_a); + if (g_cuda_b) cudaFree(g_cuda_b); + if (g_cuda_c) cudaFree(g_cuda_c); + if (g_cuda_b_bf16) cudaFree(g_cuda_b_bf16); + g_cuda_a = g_cuda_b = g_cuda_c = NULL; + g_cuda_b_bf16 = NULL; + g_cuda_a_bytes = g_cuda_b_bytes = g_cuda_c_bytes = 0; + g_cuda_b_bf16_bytes = 0; + + if (g_cuda_handle) { + cublasDestroy(g_cuda_handle); + g_cuda_handle = NULL; + } +} + +static int flux_cuda_ensure_initialized(void) { + if (g_cuda_available != -1) { + return g_cuda_available; + } + + g_cuda_available = 0; + + int device_count = 0; + cudaError_t cuda_err = cudaGetDeviceCount(&device_count); + if (cuda_err != cudaSuccess || device_count <= 0) { + return 0; + } + + if (cublasCreate(&g_cuda_handle) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA: failed to initialize cuBLAS, falling back to CPU/BLAS"); + flux_cuda_cleanup_impl(); + return 0; + } + + if (cublasSetPointerMode(g_cuda_handle, CUBLAS_POINTER_MODE_HOST) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA: failed to configure cuBLAS pointer mode, falling back to CPU/BLAS"); + flux_cuda_cleanup_impl(); + return 0; + } + if (cublasSetStream(g_cuda_handle, g_cuda_stream) != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA: failed to configure cuBLAS stream, falling back to CPU/BLAS"); + flux_cuda_cleanup_impl(); + return 0; + } + + flux_cuda_weight_cache_configure(); + atexit(flux_cuda_cleanup_impl); + g_cuda_available = 1; + return 1; +} + +static int flux_cuda_ensure_buffer(float **buffer, size_t *capacity_bytes, size_t needed_bytes) { + if (*capacity_bytes >= needed_bytes) { + return 1; + } + + size_t new_capacity = needed_bytes; + if (*capacity_bytes > 0) { + new_capacity = *capacity_bytes; + while (new_capacity < needed_bytes) { + new_capacity *= 2; + } + } + + float *new_buffer = NULL; + if (cudaMalloc((void **)&new_buffer, new_capacity) != cudaSuccess) { + flux_cuda_warn_once("CUDA: device allocation failed, falling back to CPU/BLAS"); + return 0; + } + + if (*buffer) { + cudaFree(*buffer); + } + *buffer = new_buffer; + *capacity_bytes = new_capacity; + return 1; +} + +int flux_cuda_linear_set_stream(void *stream_handle) { +#ifdef USE_CUDA + g_cuda_stream = (cudaStream_t)stream_handle; + if (!flux_cuda_ensure_initialized()) return 0; + return cublasSetStream(g_cuda_handle, g_cuda_stream) == CUBLAS_STATUS_SUCCESS; +#else + (void)stream_handle; + return 0; +#endif +} + +static int flux_cuda_ensure_buffer_bf16(uint16_t **buffer, + size_t *capacity_bytes, + size_t needed_bytes) { + if (*capacity_bytes >= needed_bytes) { + return 1; + } + + size_t new_capacity = needed_bytes; + if (*capacity_bytes > 0) { + new_capacity = *capacity_bytes; + while (new_capacity < needed_bytes) { + new_capacity *= 2; + } + } + + uint16_t *new_buffer = NULL; + if (cudaMalloc((void **)&new_buffer, new_capacity) != cudaSuccess) { + flux_cuda_warn_once("CUDA: bf16 device allocation failed, falling back to CPU/BLAS"); + return 0; + } + + if (*buffer) { + cudaFree(*buffer); + } + *buffer = new_buffer; + *capacity_bytes = new_capacity; + return 1; +} + +static int flux_cuda_probe_bf16_linear(void) { + if (!flux_cuda_ensure_initialized()) { + return 0; + } + + float h_a[16]; + uint16_t h_b[16]; + float h_c[16]; + for (int i = 0; i < 16; i++) { + float a = 0.01f * (float)(i + 1); + h_a[i] = a; + uint32_t bits; + memcpy(&bits, &a, sizeof(bits)); + h_b[i] = (uint16_t)(bits >> 16); /* f32 -> bf16 */ + h_c[i] = 0.0f; + } + + float *d_a = NULL; + uint16_t *d_b = NULL; + float *d_c = NULL; + size_t a_bytes = sizeof(h_a); + size_t b_bytes = sizeof(h_b); + size_t c_bytes = sizeof(h_c); + + if (cudaMalloc((void **)&d_a, a_bytes) != cudaSuccess || + cudaMalloc((void **)&d_b, b_bytes) != cudaSuccess || + cudaMalloc((void **)&d_c, c_bytes) != cudaSuccess) { + if (d_a) cudaFree(d_a); + if (d_b) cudaFree(d_b); + if (d_c) cudaFree(d_c); + return 0; + } + + int ok = 1; + if (cudaMemcpy(d_a, h_a, a_bytes, cudaMemcpyHostToDevice) != cudaSuccess || + cudaMemcpy(d_b, h_b, b_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + ok = 0; + goto done; + } + + const float alpha = 1.0f; + const float beta = 0.0f; +#ifdef CUBLAS_COMPUTE_32F_FAST_16BF + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; +#else + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#endif + cublasStatus_t st = cublasGemmEx(g_cuda_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + 4, 4, 4, + &alpha, + d_b, CUDA_R_16BF, 4, + d_a, CUDA_R_32F, 4, + &beta, + d_c, CUDA_R_32F, 4, + compute_type, CUBLAS_GEMM_DEFAULT); + if (st != CUBLAS_STATUS_SUCCESS) { + ok = 0; + goto done; + } + + if (cudaMemcpy(h_c, d_c, c_bytes, cudaMemcpyDeviceToHost) != cudaSuccess) { + ok = 0; + } + +done: + cudaFree(d_a); + cudaFree(d_b); + cudaFree(d_c); + return ok; +} + +/* Row-major SGEMM wrapper. + * Computes C = A @ B (transpose_b=0) or C = A @ B^T (transpose_b=1). */ +static int flux_cuda_sgemm_rowmajor(float *C, const float *A, const float *B, + int M, int K, int N, int transpose_b, + int cache_b) { + if (!flux_cuda_ensure_initialized()) { + return 0; + } + + size_t a_bytes = (size_t)M * K * sizeof(float); + size_t b_bytes = transpose_b + ? (size_t)N * K * sizeof(float) + : (size_t)K * N * sizeof(float); + size_t c_bytes = (size_t)M * N * sizeof(float); + + if (!flux_cuda_ensure_buffer(&g_cuda_a, &g_cuda_a_bytes, a_bytes) || + !flux_cuda_ensure_buffer(&g_cuda_c, &g_cuda_c_bytes, c_bytes)) { + return 0; + } + + cudaError_t cuda_err = cudaMemcpy(g_cuda_a, A, a_bytes, cudaMemcpyHostToDevice); + if (cuda_err != cudaSuccess) { + flux_cuda_warn_once("CUDA: host-to-device copy failed, falling back to CPU/BLAS"); + return 0; + } + + const float *d_b = NULL; + if (cache_b && g_cuda_weight_cache_limit_bytes > 0 && b_bytes >= CUDA_WEIGHT_CACHE_MIN_BYTES) { + uint64_t fingerprint = flux_cuda_weight_fingerprint(B, b_bytes); + int cache_idx = flux_cuda_weight_cache_find(B, K, N, transpose_b, 0, fingerprint); + if (cache_idx >= 0) { + d_b = (const float *)g_cuda_weight_cache[cache_idx].device_ptr; + } else { + d_b = (const float *)flux_cuda_weight_cache_insert(B, K, N, transpose_b, + 0, fingerprint, b_bytes); + } + } + + if (!d_b) { + if (!flux_cuda_ensure_buffer(&g_cuda_b, &g_cuda_b_bytes, b_bytes)) { + return 0; + } + cuda_err = cudaMemcpy(g_cuda_b, B, b_bytes, cudaMemcpyHostToDevice); + if (cuda_err != cudaSuccess) { + flux_cuda_warn_once("CUDA: host-to-device copy failed, falling back to CPU/BLAS"); + return 0; + } + d_b = g_cuda_b; + } + + const float alpha = 1.0f; + const float beta = 0.0f; + cublasStatus_t status; + cublasOperation_t op_a = transpose_b ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = transpose_b ? K : N; + + /* Prefer TF32 tensor-core GEMM on supported GPUs. + * Disable with FLUX_CUDA_NO_TF32=1 for exact-fp32 fallback testing. */ + if (flux_cuda_tf32_enabled()) { +#ifdef CUBLAS_COMPUTE_32F_FAST_TF32 + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; +#else + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#endif + status = cublasGemmEx(g_cuda_handle, + op_a, CUBLAS_OP_N, + N, M, K, + &alpha, + d_b, CUDA_R_32F, lda, + g_cuda_a, CUDA_R_32F, K, + &beta, + g_cuda_c, CUDA_R_32F, N, + compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + } else { + status = CUBLAS_STATUS_NOT_SUPPORTED; + } + + if (status != CUBLAS_STATUS_SUCCESS) { + /* Row-major A[M,K] @ B[*,K] -> column-major mapping for cuBLAS */ + status = cublasSgemm(g_cuda_handle, + op_a, CUBLAS_OP_N, + N, M, K, + &alpha, + d_b, lda, + g_cuda_a, K, + &beta, + g_cuda_c, N); + } + + if (status != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA: cuBLAS SGEMM failed, falling back to CPU/BLAS"); + return 0; + } + + cuda_err = cudaMemcpy(C, g_cuda_c, c_bytes, cudaMemcpyDeviceToHost); + if (cuda_err != cudaSuccess) { + flux_cuda_warn_once("CUDA: device-to-host copy failed, falling back to CPU/BLAS"); + return 0; + } + + return 1; +} + +/* Row-major SGEMM wrapper with device input/output and host-side weight matrix. + * d_C = d_A @ B (transpose_b=0) or d_A @ B^T (transpose_b=1) + * d_A/d_C are device pointers, B is host pointer (cached on device when possible). */ +static int flux_cuda_sgemm_rowmajor_device_weight(float *d_C, + const float *d_A, + const float *B, + int M, int K, int N, + int transpose_b, + int cache_b) { + if (!flux_cuda_ensure_initialized()) { + return 0; + } + if (!d_C || !d_A || !B || M <= 0 || K <= 0 || N <= 0) { + return 0; + } + + size_t b_bytes = transpose_b + ? (size_t)N * K * sizeof(float) + : (size_t)K * N * sizeof(float); + + const float *d_b = NULL; + if (cache_b && g_cuda_weight_cache_limit_bytes > 0 && b_bytes >= CUDA_WEIGHT_CACHE_MIN_BYTES) { + uint64_t fingerprint = flux_cuda_weight_fingerprint(B, b_bytes); + int cache_idx = flux_cuda_weight_cache_find(B, K, N, transpose_b, 0, fingerprint); + if (cache_idx >= 0) { + d_b = (const float *)g_cuda_weight_cache[cache_idx].device_ptr; + } else { + d_b = (const float *)flux_cuda_weight_cache_insert(B, K, N, transpose_b, + 0, fingerprint, b_bytes); + } + } + + if (!d_b) { + if (!flux_cuda_ensure_buffer(&g_cuda_b, &g_cuda_b_bytes, b_bytes)) { + return 0; + } + if (cudaMemcpy(g_cuda_b, B, b_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { + flux_cuda_warn_once("CUDA: host-to-device copy failed, falling back to CPU/BLAS"); + return 0; + } + d_b = g_cuda_b; + } + + const float alpha = 1.0f; + const float beta = 0.0f; + cublasOperation_t op_a = transpose_b ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = transpose_b ? K : N; + + cublasStatus_t status = CUBLAS_STATUS_NOT_SUPPORTED; + if (flux_cuda_tf32_enabled()) { +#ifdef CUBLAS_COMPUTE_32F_FAST_TF32 + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; +#else + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#endif + status = cublasGemmEx(g_cuda_handle, + op_a, CUBLAS_OP_N, + N, M, K, + &alpha, + d_b, CUDA_R_32F, lda, + d_A, CUDA_R_32F, K, + &beta, + d_C, CUDA_R_32F, N, + compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + } + + if (status != CUBLAS_STATUS_SUCCESS) { + status = cublasSgemm(g_cuda_handle, + op_a, CUBLAS_OP_N, + N, M, K, + &alpha, + d_b, lda, + d_A, K, + &beta, + d_C, N); + } + + if (status != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA: cuBLAS SGEMM failed, falling back to CPU/BLAS"); + return 0; + } + + return 1; +} + +/* Row-major SGEMM wrapper with bf16 weights: + * C = A @ B (transpose_b=0) or C = A @ B^T (transpose_b=1) + * A/C are f32, B is bf16 (uint16 storage). */ +static int flux_cuda_sgemm_rowmajor_bf16_weight(float *C, + const float *A, + const uint16_t *B_bf16, + int M, int K, int N, + int transpose_b, + int cache_b) { + if (!flux_cuda_ensure_initialized()) { + return 0; + } + + size_t a_bytes = (size_t)M * K * sizeof(float); + size_t b_bytes = transpose_b + ? (size_t)N * K * sizeof(uint16_t) + : (size_t)K * N * sizeof(uint16_t); + size_t c_bytes = (size_t)M * N * sizeof(float); + + if (!flux_cuda_ensure_buffer(&g_cuda_a, &g_cuda_a_bytes, a_bytes) || + !flux_cuda_ensure_buffer(&g_cuda_c, &g_cuda_c_bytes, c_bytes)) { + return 0; + } + + cudaError_t cuda_err = cudaMemcpy(g_cuda_a, A, a_bytes, cudaMemcpyHostToDevice); + if (cuda_err != cudaSuccess) { + flux_cuda_warn_once("CUDA: bf16 host-to-device copy failed, falling back to CPU/BLAS"); + return 0; + } + + const uint16_t *d_b = NULL; + if (cache_b && g_cuda_weight_cache_limit_bytes > 0 && b_bytes >= CUDA_WEIGHT_CACHE_MIN_BYTES) { + uint64_t fingerprint = flux_cuda_weight_fingerprint(B_bf16, b_bytes); + int cache_idx = flux_cuda_weight_cache_find(B_bf16, K, N, transpose_b, 1, fingerprint); + if (cache_idx >= 0) { + d_b = (const uint16_t *)g_cuda_weight_cache[cache_idx].device_ptr; + } else { + d_b = (const uint16_t *)flux_cuda_weight_cache_insert(B_bf16, K, N, transpose_b, + 1, fingerprint, b_bytes); + } + } + + if (!d_b) { + if (!flux_cuda_ensure_buffer_bf16(&g_cuda_b_bf16, &g_cuda_b_bf16_bytes, b_bytes)) { + return 0; + } + cuda_err = cudaMemcpy(g_cuda_b_bf16, B_bf16, b_bytes, cudaMemcpyHostToDevice); + if (cuda_err != cudaSuccess) { + flux_cuda_warn_once("CUDA: bf16 host-to-device copy failed, falling back to CPU/BLAS"); + return 0; + } + d_b = g_cuda_b_bf16; + } + + const float alpha = 1.0f; + const float beta = 0.0f; +#ifdef CUBLAS_COMPUTE_32F_FAST_16BF + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; +#else + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#endif + + cublasStatus_t status; + if (transpose_b) { + status = cublasGemmEx(g_cuda_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + N, M, K, + &alpha, + d_b, CUDA_R_16BF, K, + g_cuda_a, CUDA_R_32F, K, + &beta, + g_cuda_c, CUDA_R_32F, N, + compute_type, CUBLAS_GEMM_DEFAULT); + } else { + status = cublasGemmEx(g_cuda_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + N, M, K, + &alpha, + d_b, CUDA_R_16BF, N, + g_cuda_a, CUDA_R_32F, K, + &beta, + g_cuda_c, CUDA_R_32F, N, + compute_type, CUBLAS_GEMM_DEFAULT); + } + + if (status != CUBLAS_STATUS_SUCCESS) { + flux_cuda_warn_once("CUDA: bf16 cuBLAS GEMM failed, falling back to CPU/BLAS"); + return 0; + } + + cuda_err = cudaMemcpy(C, g_cuda_c, c_bytes, cudaMemcpyDeviceToHost); + if (cuda_err != cudaSuccess) { + flux_cuda_warn_once("CUDA: bf16 device-to-host copy failed, falling back to CPU/BLAS"); + return 0; + } + + return 1; +} + +int flux_cuda_bf16_linear_available(void) { + if (g_cuda_bf16_linear_ok != -1) { + return g_cuda_bf16_linear_ok; + } + + g_cuda_bf16_linear_ok = flux_cuda_probe_bf16_linear() ? 1 : 0; + return g_cuda_bf16_linear_ok; +} +#endif + +#ifndef USE_CUDA +int flux_cuda_bf16_linear_available(void) { + return 0; +} +#endif + +int flux_cuda_linear_nobias_device(float *d_y, const float *d_x, const float *W, + int seq_len, int in_dim, int out_dim) { +#ifdef USE_CUDA + return flux_cuda_sgemm_rowmajor_device_weight(d_y, d_x, W, + seq_len, in_dim, out_dim, + 1, 1); +#else + (void)d_y; (void)d_x; (void)W; (void)seq_len; (void)in_dim; (void)out_dim; + return 0; +#endif +} /* Progress callbacks - set by caller before inference */ flux_substep_callback_t flux_substep_callback = NULL; @@ -153,6 +940,14 @@ void flux_matmul(float *C, const float *A, const float *B, } #endif +#ifdef USE_CUDA + size_t matrix_ops_cuda = (size_t)M * K * N; + if (matrix_ops_cuda >= flux_cuda_get_min_ops() && + flux_cuda_sgemm_rowmajor(C, A, B, M, K, N, 0, 0)) { + return; + } +#endif + #ifdef USE_BLAS cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, @@ -190,6 +985,14 @@ void flux_matmul_t(float *C, const float *A, const float *B, } #endif +#ifdef USE_CUDA + size_t matrix_ops_cuda = (size_t)M * K * N; + if (matrix_ops_cuda >= flux_cuda_get_min_ops() && + flux_cuda_sgemm_rowmajor(C, A, B, M, K, N, 1, 0)) { + return; + } +#endif + #ifdef USE_BLAS cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, K, @@ -242,6 +1045,21 @@ void flux_linear(float *y, const float *x, const float *W, const float *b, } #endif +#ifdef USE_CUDA + size_t matrix_ops_cuda = (size_t)seq_len * in_dim * out_dim; + if (matrix_ops_cuda >= flux_cuda_get_min_ops() && + flux_cuda_sgemm_rowmajor(y, x, W, seq_len, in_dim, out_dim, 1, 1)) { + if (b != NULL) { + for (int s = 0; s < seq_len; s++) { + for (int o = 0; o < out_dim; o++) { + y[s * out_dim + o] += b[o]; + } + } + } + return; + } +#endif + #ifdef USE_BLAS /* Use BLAS sgemm: C = alpha * A @ B^T + beta * C * A[M, K] = x[seq_len, in_dim] @@ -307,6 +1125,16 @@ void flux_linear_nobias_bf16(float *y, const float *x, const uint16_t *W_bf16, } #endif +#ifdef USE_CUDA + size_t matrix_ops_cuda = (size_t)seq_len * in_dim * out_dim; + if (matrix_ops_cuda >= flux_cuda_get_min_ops() && + flux_cuda_sgemm_rowmajor_bf16_weight(y, x, W_bf16, + seq_len, in_dim, out_dim, + 1, 1)) { + return; + } +#endif + /* Fallback: convert bf16 to f32 and use regular linear */ float *W_f32 = (float *)malloc((size_t)out_dim * in_dim * sizeof(float)); if (!W_f32) return; @@ -404,15 +1232,28 @@ void flux_conv2d(float *out, const float *in, const float *weight, const float * * where K = in_ch * kH * kW */ int K = in_ch * kH * kW; - /* Write sgemm output directly to out_b using strided ldc. - * Row oc of sgemm output goes to out_b[oc * outH*outW + tile_start*outW], - * which is exactly the right position in NCHW layout. */ - float *out_tile = out_b + tile_start * outW; + /* Allocate temporary contiguous buffer for tile output */ + float *tmp = malloc((size_t)out_ch * tile_pixels * sizeof(float)); + if (!tmp) { + free(col); + goto naive_fallback; + } + + /* sgemm: tmp[out_ch, tile_pixels] = weight[out_ch, K] @ col[K, tile_pixels] */ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, out_ch, tile_pixels, K, 1.0f, weight, K, col, tile_pixels, - 0.0f, out_tile, outH * outW); + 0.0f, tmp, tile_pixels); + + /* Scatter tile output to correct positions in out_b */ + for (int oc = 0; oc < out_ch; oc++) { + float *out_tile = out_b + oc * outH * outW + tile_start * outW; + float *tmp_row = tmp + oc * tile_pixels; + memcpy(out_tile, tmp_row, tile_pixels * sizeof(float)); + } + + free(tmp); } /* Add bias */ @@ -578,11 +1419,11 @@ void flux_silu(float *x, int n) { for (int i = 0; i < n; i++) { float val = x[i]; - x[i] = val / (1.0f + fast_expf(-val)); + x[i] = val / (1.0f + expf(-val)); } } -/* Fused SiLU(gate) * up in a single pass - avoids double memory traversal */ +/* Fused SiLU(gate) * up in a single pass. */ void flux_silu_mul(float *gate, const float *up, int n) { #ifdef USE_METAL if (flux_metal_shaders_available() && n >= 4 * 1024 * 1024) { @@ -590,15 +1431,22 @@ void flux_silu_mul(float *gate, const float *up, int n) { return; } #endif - for (int i = 0; i < n; i++) { - float val = gate[i]; - gate[i] = (val / (1.0f + fast_expf(-val))) * up[i]; + float v = gate[i]; + gate[i] = (v / (1.0f + expf(-v))) * up[i]; } } -/* CPU-only softmax. Safe to call from worker threads (no Metal dispatch). */ -void flux_softmax_cpu(float *x, int rows, int cols) { +void flux_softmax(float *x, int rows, int cols) { +#ifdef USE_METAL + /* Use GPU only for very large softmax operations + * Sync overhead usually dominates for smaller ops */ + if (flux_metal_shaders_available() && (size_t)rows * cols >= 4 * 1024 * 1024) { + flux_metal_softmax(x, rows, cols); + return; + } +#endif + for (int r = 0; r < rows; r++) { float *row = x + r * cols; @@ -611,7 +1459,7 @@ void flux_softmax_cpu(float *x, int rows, int cols) { /* Compute exp and sum */ float sum = 0.0f; for (int c = 0; c < cols; c++) { - row[c] = fast_expf(row[c] - max_val); + row[c] = expf(row[c] - max_val); sum += row[c]; } @@ -623,18 +1471,6 @@ void flux_softmax_cpu(float *x, int rows, int cols) { } } -void flux_softmax(float *x, int rows, int cols) { -#ifdef USE_METAL - /* Use GPU only for very large softmax operations - * Sync overhead usually dominates for smaller ops */ - if (flux_metal_shaders_available() && (size_t)rows * cols >= 4 * 1024 * 1024) { - flux_metal_softmax(x, rows, cols); - return; - } -#endif - flux_softmax_cpu(x, rows, cols); -} - /* ======================================================================== * Attention Operations * ======================================================================== */ @@ -737,7 +1573,7 @@ static void flash_attention_head(float *out, /* Online softmax update */ if (score > max_score) { /* New maximum found - rescale previous accumulations */ - float correction = fast_expf(max_score - score); + float correction = expf(max_score - score); sum_exp = sum_exp * correction + 1.0f; for (int d = 0; d < head_dim; d++) { o_row[d] = o_row[d] * correction + v_row[d]; @@ -745,7 +1581,7 @@ static void flash_attention_head(float *out, max_score = score; } else { /* Score is less than current max */ - float weight = fast_expf(score - max_score); + float weight = expf(score - max_score); sum_exp += weight; for (int d = 0; d < head_dim; d++) { o_row[d] += weight * v_row[d]; @@ -836,7 +1672,7 @@ static void flash_attention_head_tiled(float *out, /* Rescale old accumulations if needed */ if (old_max > -1e29f) { /* Check if we have prior accumulations */ - float correction = fast_expf(old_max - new_max); + float correction = expf(old_max - new_max); sum_exps[i] *= correction; for (int d = 0; d < head_dim; d++) { o_row[d] *= correction; @@ -845,7 +1681,7 @@ static void flash_attention_head_tiled(float *out, /* Accumulate this tile's contribution */ for (int ki = 0; ki < k_len; ki++) { - float weight = fast_expf(score_row[ki] - new_max); + float weight = expf(score_row[ki] - new_max); sum_exps[i] += weight; const float *v_row = V_tile + ki * head_dim; for (int d = 0; d < head_dim; d++) { diff --git a/flux_kernels.h b/flux_kernels.h index 910c667..8de70bf 100644 --- a/flux_kernels.h +++ b/flux_kernels.h @@ -83,6 +83,22 @@ void flux_linear_nobias(float *y, const float *x, const float *W, void flux_linear_nobias_bf16(float *y, const float *x, const uint16_t *W_bf16, int seq_len, int in_dim, int out_dim); +/* Returns 1 when CUDA bf16 linear GEMM is available on this runtime, else 0. + * On non-CUDA builds, always returns 0. */ +int flux_cuda_bf16_linear_available(void); + +/* CUDA-only helper: linear without bias with device input/output. + * d_y: device [seq_len, out_dim], d_x: device [seq_len, in_dim], W: host [out_dim, in_dim] + * Returns 1 on success, 0 on failure/fallback condition. + */ +int flux_cuda_linear_nobias_device(float *d_y, const float *d_x, const float *W, + int seq_len, int in_dim, int out_dim); + +/* Set the CUDA stream used by CUDA linear helpers. + * Pass NULL to restore the default stream. + * Returns 1 on success, 0 on failure/non-CUDA builds. */ +int flux_cuda_linear_set_stream(void *stream_handle); + /* ======================================================================== * GPU Batch Operations * These functions allow batching multiple GPU operations to reduce sync overhead. diff --git a/flux_qwen3.c b/flux_qwen3.c index 4b3cd04..5177245 100644 --- a/flux_qwen3.c +++ b/flux_qwen3.c @@ -30,6 +30,9 @@ #ifdef USE_METAL #include "flux_metal.h" #endif +#ifdef USE_CUDA +#include "flux_cuda.h" +#endif /* Minimum matrix size for GPU acceleration. * Using 10M threshold keeps text encoder on CPU (Accelerate BLAS), which is @@ -126,6 +129,11 @@ struct qwen3_model { float *attn_q_head; /* [seq_len, head_dim] */ float *attn_v_head; /* [seq_len, head_dim] */ float *attn_out_head; /* [seq_len, head_dim] */ + /* CUDA batched attention staging buffers [num_heads, seq_len, head_dim] */ + float *attn_q_hsd; + float *attn_k_hsd; + float *attn_v_hsd; + float *attn_out_hsd; /* Mmap mode: keep safetensors files open, load layer weights on-demand */ int use_mmap; @@ -169,11 +177,9 @@ static void qwen3_linear(float *y, const float *x, const float *W, } #endif -#ifdef USE_BLAS - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - seq_len, out_dim, in_dim, - 1.0f, x, in_dim, W, in_dim, - 0.0f, y, out_dim); +#if defined(USE_BLAS) || defined(USE_CUDA) + /* Route through shared GEMM so CUDA builds can use GPU linear kernels. */ + flux_matmul_t(y, x, W, seq_len, in_dim, out_dim); #else for (int s = 0; s < seq_len; s++) { for (int o = 0; o < out_dim; o++) { @@ -366,6 +372,46 @@ static void qwen3_attention_forward(qwen3_model_t *model, qwen3_layer_t *layer, { int heads_per_kv = num_heads / num_kv_heads; +#ifdef USE_CUDA + /* CUDA batched GQA path: one attention launch for all query heads. + * Pack Q/K/V from [seq, hidden] strided layout to [heads, seq, head_dim]. */ + if (model->attn_q_hsd && model->attn_k_hsd && model->attn_v_hsd && model->attn_out_hsd) { + float *q_hsd = model->attn_q_hsd; + float *k_hsd = model->attn_k_hsd; + float *v_hsd = model->attn_v_hsd; + float *out_hsd = model->attn_out_hsd; + + for (int h = 0; h < num_heads; h++) { + int kv_h = h / heads_per_kv; + float *q_h = q_hsd + (size_t)h * seq_len * head_dim; + float *k_h = k_hsd + (size_t)h * seq_len * head_dim; + float *v_h = v_hsd + (size_t)h * seq_len * head_dim; + for (int s = 0; s < seq_len; s++) { + const float *q_src = model->q_buf + s * q_dim + h * head_dim; + const float *k_src = model->k_buf + s * kv_dim + kv_h * head_dim; + const float *v_src = model->v_buf + s * kv_dim + kv_h * head_dim; + memcpy(q_h + s * head_dim, q_src, head_dim * sizeof(float)); + memcpy(k_h + s * head_dim, k_src, head_dim * sizeof(float)); + memcpy(v_h + s * head_dim, v_src, head_dim * sizeof(float)); + } + } + + if (flux_cuda_attention_batched(out_hsd, q_hsd, k_hsd, v_hsd, + num_heads, seq_len, seq_len, head_dim, + scale, 1, attention_mask)) { + /* Unpack back to [seq, heads * head_dim]. */ + for (int h = 0; h < num_heads; h++) { + const float *out_h = out_hsd + (size_t)h * seq_len * head_dim; + for (int s = 0; s < seq_len; s++) { + float *dst = model->attn_out + s * q_dim + h * head_dim; + memcpy(dst, out_h + s * head_dim, head_dim * sizeof(float)); + } + } + goto output_proj; + } + } +#endif + for (int h = 0; h < num_heads; h++) { int kv_h = h / heads_per_kv; /* Which KV head to use */ float *scores = model->attn_scores + h * seq_len * seq_len; @@ -441,7 +487,7 @@ static void qwen3_attention_forward(qwen3_model_t *model, qwen3_layer_t *layer, /* Work buffers are pre-allocated in model, no free needed */ -#ifdef USE_METAL +#if defined(USE_METAL) || defined(USE_CUDA) output_proj: #endif /* Output projection */ @@ -1502,6 +1548,10 @@ static void qwen3_alloc_work_buffers(qwen3_model_t *model) { model->attn_q_head = malloc(seq_len * head_dim * sizeof(float)); model->attn_v_head = malloc(seq_len * head_dim * sizeof(float)); model->attn_out_head = malloc(seq_len * head_dim * sizeof(float)); + model->attn_q_hsd = malloc(seq_len * num_heads * head_dim * sizeof(float)); + model->attn_k_hsd = malloc(seq_len * num_heads * head_dim * sizeof(float)); + model->attn_v_hsd = malloc(seq_len * num_heads * head_dim * sizeof(float)); + model->attn_out_hsd = malloc(seq_len * num_heads * head_dim * sizeof(float)); for (int i = 0; i < 3; i++) { model->layer_outputs[i] = malloc(seq_len * hidden * sizeof(float)); @@ -1691,6 +1741,10 @@ void qwen3_model_free(qwen3_model_t *model) { free(model->attn_q_head); free(model->attn_v_head); free(model->attn_out_head); + free(model->attn_q_hsd); + free(model->attn_k_hsd); + free(model->attn_v_hsd); + free(model->attn_out_hsd); for (int i = 0; i < 3; i++) { free(model->layer_outputs[i]); diff --git a/flux_transformer.c b/flux_transformer.c index d69548d..7096476 100644 --- a/flux_transformer.c +++ b/flux_transformer.c @@ -29,50 +29,16 @@ extern double flux_timing_transformer_double; extern double flux_timing_transformer_single; extern double flux_timing_transformer_final; -/* Fine-grained profiling for BLAS optimization */ -static double prof_single_adaln = 0; -static double prof_single_fused_matmul = 0; -static double prof_single_split = 0; -static double prof_single_qknorm_rope = 0; -static double prof_single_attention = 0; -static double prof_single_swiglu = 0; -static double prof_single_proj_matmul = 0; -static double prof_single_gated_add = 0; - -static double prof_get_time(void) { +/* Helper to get current time in ms (wall-clock) */ +static double tf_get_time_ms(void) { struct timeval tv; gettimeofday(&tv, NULL); return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0; } +/* Compatibility hook used by sampler when BLAS profiling is enabled in + * some branches. This branch does not maintain per-op BLAS counters. */ void flux_print_blas_profile(void) { - double total = prof_single_adaln + prof_single_fused_matmul + prof_single_split + - prof_single_qknorm_rope + prof_single_attention + prof_single_swiglu + - prof_single_proj_matmul + prof_single_gated_add; - if (total < 1.0) return; - fprintf(stderr, "\nSingle block breakdown (cumulative):\n"); - fprintf(stderr, " AdaLN+mod: %7.1fms (%4.1f%%)\n", prof_single_adaln, 100*prof_single_adaln/total); - fprintf(stderr, " Fused QKV+MLP: %7.1fms (%4.1f%%)\n", prof_single_fused_matmul, 100*prof_single_fused_matmul/total); - fprintf(stderr, " Split: %7.1fms (%4.1f%%)\n", prof_single_split, 100*prof_single_split/total); - fprintf(stderr, " QKnorm+RoPE: %7.1fms (%4.1f%%)\n", prof_single_qknorm_rope, 100*prof_single_qknorm_rope/total); - fprintf(stderr, " Attention: %7.1fms (%4.1f%%)\n", prof_single_attention, 100*prof_single_attention/total); - fprintf(stderr, " SwiGLU: %7.1fms (%4.1f%%)\n", prof_single_swiglu, 100*prof_single_swiglu/total); - fprintf(stderr, " Proj matmul: %7.1fms (%4.1f%%)\n", prof_single_proj_matmul, 100*prof_single_proj_matmul/total); - fprintf(stderr, " Gated add: %7.1fms (%4.1f%%)\n", prof_single_gated_add, 100*prof_single_gated_add/total); - fprintf(stderr, " Total: %7.1fms\n", total); -} - -void flux_reset_blas_profile(void) { - prof_single_adaln = prof_single_fused_matmul = prof_single_split = 0; - prof_single_qknorm_rope = prof_single_attention = prof_single_swiglu = 0; - prof_single_proj_matmul = prof_single_gated_add = 0; -} - -/* Helper to get current time in ms (wall-clock) */ -static double tf_get_time_ms(void) { - struct timeval tv; - gettimeofday(&tv, NULL); - return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0; } /* Use BLAS for matrix operations when enabled via Makefile */ @@ -82,8 +48,6 @@ static double tf_get_time_ms(void) { #else #include #endif -#include -#include #endif /* Use Metal for GPU acceleration when available */ @@ -91,6 +55,11 @@ static double tf_get_time_ms(void) { #include "flux_metal.h" #endif +#ifdef USE_CUDA +#include "flux_cuda.h" +#include +#endif + /* Enable BF16 pipeline debug logging when FLUX_BF16_DEBUG is set. */ #ifdef USE_METAL static int bf16_debug_enabled(void) { @@ -295,6 +264,7 @@ typedef struct flux_transformer { float *t_emb_silu; /* [hidden] */ float *double_mod_img; /* [hidden * 6] */ float *double_mod_txt; /* [hidden * 6] */ + float *single_mod_params; /* [hidden * 3] */ float *double_img_attn_out; /* [max_seq, hidden] */ float *double_txt_attn_out; /* [max_seq, hidden] */ @@ -322,175 +292,14 @@ typedef struct flux_transformer { /* Mmap mode: keep safetensors file open, load block weights on-demand */ int use_mmap; - #define MAX_TF_SHARDS 4 - safetensors_file_t *sf_files[MAX_TF_SHARDS]; - int num_sf_files; + safetensors_file_t *sf; } flux_transformer_t; -/* ======================================================================== - * Transformer Config Parsing - * ======================================================================== */ - -/* Parse transformer/config.json to get architecture dimensions. - * Returns 0 on success, -1 on failure (caller should use defaults). */ -static int parse_transformer_config(const char *model_dir, flux_transformer_t *tf) { - char path[1024]; - snprintf(path, sizeof(path), "%s/transformer/config.json", model_dir); - - FILE *f = fopen(path, "r"); - if (!f) return -1; - - char buf[4096]; - size_t n = fread(buf, 1, sizeof(buf) - 1, f); - buf[n] = '\0'; - fclose(f); - - /* Simple JSON integer/float extraction. - * Look for "key": value patterns. */ - char *p; - int num_heads = 0, head_dim = 0, num_layers = 0, num_single = 0; - int joint_attention_dim = 0, in_channels = 0; - float mlp_ratio = 0, rope_theta = 0; - - if ((p = strstr(buf, "\"num_attention_heads\""))) { - if ((p = strchr(p, ':'))) num_heads = atoi(p + 1); - } - if ((p = strstr(buf, "\"attention_head_dim\""))) { - if ((p = strchr(p, ':'))) head_dim = atoi(p + 1); - } - if ((p = strstr(buf, "\"num_layers\""))) { - if ((p = strchr(p, ':'))) num_layers = atoi(p + 1); - } - if ((p = strstr(buf, "\"num_single_layers\""))) { - if ((p = strchr(p, ':'))) num_single = atoi(p + 1); - } - if ((p = strstr(buf, "\"joint_attention_dim\""))) { - if ((p = strchr(p, ':'))) joint_attention_dim = atoi(p + 1); - } - if ((p = strstr(buf, "\"in_channels\""))) { - if ((p = strchr(p, ':'))) in_channels = atoi(p + 1); - } - if ((p = strstr(buf, "\"mlp_ratio\""))) { - if ((p = strchr(p, ':'))) mlp_ratio = atof(p + 1); - } - if ((p = strstr(buf, "\"rope_theta\""))) { - if ((p = strchr(p, ':'))) rope_theta = atof(p + 1); - } - - /* Validate: we need at least heads and head_dim */ - if (num_heads <= 0 || head_dim <= 0) return -1; - - tf->num_heads = num_heads; - tf->head_dim = head_dim; - tf->hidden_size = num_heads * head_dim; - tf->mlp_hidden = (int)(tf->hidden_size * (mlp_ratio > 0 ? mlp_ratio : 3.0f)); - tf->num_double_layers = num_layers > 0 ? num_layers : 5; - tf->num_single_layers = num_single > 0 ? num_single : 20; - tf->text_dim = joint_attention_dim > 0 ? joint_attention_dim : 7680; - tf->latent_channels = in_channels > 0 ? in_channels : 128; - tf->rope_theta = rope_theta > 0 ? rope_theta : 2000.0f; - tf->rope_dim = head_dim; - tf->axis_dim = head_dim / 4; /* 4 RoPE axes */ - - return 0; -} - -/* Open transformer safetensors shards. - * Reads diffusion_pytorch_model.safetensors.index.json for shard filenames, - * falls back to single diffusion_pytorch_model.safetensors. - * Returns number of files opened (0 on failure). */ -static int open_transformer_shards(const char *model_dir, - safetensors_file_t **files, int max_files) { - char path[1024]; - int num_files = 0; - - /* Try to read index JSON for sharded models */ - snprintf(path, sizeof(path), "%s/transformer/diffusion_pytorch_model.safetensors.index.json", - model_dir); - FILE *fp = fopen(path, "r"); - if (fp) { - fseek(fp, 0, SEEK_END); - long len = ftell(fp); - fseek(fp, 0, SEEK_SET); - char *json = malloc(len + 1); - if (json) { - fread(json, 1, len, fp); - json[len] = '\0'; - fclose(fp); - - /* Extract unique shard filenames from weight_map values */ - char shard_names[MAX_TF_SHARDS][256]; - int num_shards = 0; - const char *p = strstr(json, "\"weight_map\""); - if (p) { - p = strchr(p, '{'); - if (p) p++; - while (p && num_shards < max_files) { - /* Find next value (shard filename) */ - const char *colon = strchr(p, ':'); - if (!colon) break; - const char *q1 = strchr(colon, '"'); - if (!q1) break; - q1++; - const char *q2 = strchr(q1, '"'); - if (!q2) break; - int slen = (int)(q2 - q1); - if (slen > 0 && slen < 256) { - char fname[256]; - memcpy(fname, q1, slen); - fname[slen] = '\0'; - /* Check if already seen */ - int dup = 0; - for (int i = 0; i < num_shards; i++) { - if (strcmp(shard_names[i], fname) == 0) { dup = 1; break; } - } - if (!dup) { - strcpy(shard_names[num_shards], fname); - num_shards++; - } - } - p = q2 + 1; - /* Skip to next key or end */ - const char *comma = strchr(p, ','); - const char *brace = strchr(p, '}'); - if (brace && (!comma || brace < comma)) break; - if (comma) p = comma + 1; else break; - } - } - free(json); - - /* Open each shard - all must succeed */ - for (int i = 0; i < num_shards; i++) { - snprintf(path, sizeof(path), "%s/transformer/%s", model_dir, shard_names[i]); - files[num_files] = safetensors_open(path); - if (files[num_files]) { - num_files++; - } else { - fprintf(stderr, "Error: failed to open transformer shard %s\n", shard_names[i]); - for (int j = 0; j < num_files; j++) safetensors_close(files[j]); - return 0; - } - } - if (num_files > 0) return num_files; - } else { - fclose(fp); - } - } - - /* Fallback: single file */ - snprintf(path, sizeof(path), "%s/transformer/diffusion_pytorch_model.safetensors", - model_dir); - files[0] = safetensors_open(path); - if (files[0]) return 1; - - return 0; -} - /* Forward declarations */ void flux_transformer_free(flux_transformer_t *tf); -static int load_double_block_weights(double_block_t *b, safetensors_file_t **files, int num_files, int idx, int h, int mlp, int use_bf16); +static int load_double_block_weights(double_block_t *b, safetensors_file_t *sf, int idx, int h, int mlp, int use_bf16); static void free_double_block_weights(double_block_t *b); -static int load_single_block_weights(single_block_t *b, safetensors_file_t **files, int num_files, int idx, int h, int mlp, int use_bf16); +static int load_single_block_weights(single_block_t *b, safetensors_file_t *sf, int idx, int h, int mlp, int use_bf16); static void free_single_block_weights(single_block_t *b); /* ======================================================================== @@ -498,62 +307,61 @@ static void free_single_block_weights(single_block_t *b); * ======================================================================== */ /* Helper to get tensor as f32 (used by mmap load functions) */ -static float *mmap_get_f32(safetensors_file_t **files, int num_files, const char *name) { - for (int f = 0; f < num_files; f++) { - const safetensor_t *t = safetensors_find(files[f], name); - if (t) return safetensors_get_f32(files[f], t); +static float *mmap_get_f32(safetensors_file_t *sf, const char *name) { + const safetensor_t *t = safetensors_find(sf, name); + if (!t) { + fprintf(stderr, "Error: required tensor %s not found\n", name); + return NULL; } - fprintf(stderr, "Error: required tensor %s not found\n", name); - return NULL; + return safetensors_get_f32(sf, t); } /* Helper to get tensor as bf16 direct pointer (used by mmap load functions) * Returns pointer into mmap'd region - caller must NOT free */ -static uint16_t *mmap_get_bf16(safetensors_file_t **files, int num_files, const char *name) { - for (int f = 0; f < num_files; f++) { - const safetensor_t *t = safetensors_find(files[f], name); - if (t && safetensor_is_bf16(t)) return safetensors_get_bf16_direct(files[f], t); - } - return NULL; +static uint16_t *mmap_get_bf16(safetensors_file_t *sf, const char *name) { + const safetensor_t *t = safetensors_find(sf, name); + if (!t) return NULL; + if (!safetensor_is_bf16(t)) return NULL; + return safetensors_get_bf16_direct(sf, t); } /* Load weights for a single double_block on-demand */ -static int load_double_block_weights(double_block_t *b, safetensors_file_t **files, - int num_files, int idx, int h, int mlp, int use_bf16) { +static int load_double_block_weights(double_block_t *b, safetensors_file_t *sf, + int idx, int h, int mlp, int use_bf16) { char name[256]; /* Image attention - QK norm weights (always f32) */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_q.weight", idx); - b->img_norm_q_weight = mmap_get_f32(files, num_files, name); + b->img_norm_q_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_k.weight", idx); - b->img_norm_k_weight = mmap_get_f32(files, num_files, name); + b->img_norm_k_weight = mmap_get_f32(sf, name); /* Image Q, K, V projections - skip f32 if bf16 available */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_q.weight", idx); - if (use_bf16) b->img_q_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->img_q_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->img_q_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->img_q_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_k.weight", idx); - if (use_bf16) b->img_k_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->img_k_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->img_k_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->img_k_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_v.weight", idx); - if (use_bf16) b->img_v_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->img_v_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->img_v_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->img_v_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_out.0.weight", idx); - if (use_bf16) b->img_proj_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->img_proj_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->img_proj_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->img_proj_weight = mmap_get_f32(sf, name); /* Image FFN - linear_in contains gate and up fused - skip f32 if bf16 available */ snprintf(name, sizeof(name), "transformer_blocks.%d.ff.linear_in.weight", idx); if (use_bf16) { - uint16_t *ff_in_bf16 = mmap_get_bf16(files, num_files, name); + uint16_t *ff_in_bf16 = mmap_get_bf16(sf, name); if (ff_in_bf16) { /* Direct pointers with offset - no malloc/copy needed */ b->img_mlp_gate_weight_bf16 = ff_in_bf16; b->img_mlp_up_weight_bf16 = ff_in_bf16 + (size_t)mlp * h; } } else { - float *ff_in = mmap_get_f32(files, num_files, name); + float *ff_in = mmap_get_f32(sf, name); if (ff_in) { b->img_mlp_gate_weight = malloc(mlp * h * sizeof(float)); b->img_mlp_up_weight = malloc(mlp * h * sizeof(float)); @@ -564,41 +372,41 @@ static int load_double_block_weights(double_block_t *b, safetensors_file_t **fil } snprintf(name, sizeof(name), "transformer_blocks.%d.ff.linear_out.weight", idx); - if (use_bf16) b->img_mlp_down_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->img_mlp_down_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->img_mlp_down_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->img_mlp_down_weight = mmap_get_f32(sf, name); /* Text stream - QK norm weights (always f32) */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_added_q.weight", idx); - b->txt_norm_q_weight = mmap_get_f32(files, num_files, name); + b->txt_norm_q_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_added_k.weight", idx); - b->txt_norm_k_weight = mmap_get_f32(files, num_files, name); + b->txt_norm_k_weight = mmap_get_f32(sf, name); /* Text Q, K, V projections - skip f32 if bf16 available */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.add_q_proj.weight", idx); - if (use_bf16) b->txt_q_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->txt_q_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->txt_q_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->txt_q_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.add_k_proj.weight", idx); - if (use_bf16) b->txt_k_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->txt_k_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->txt_k_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->txt_k_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.add_v_proj.weight", idx); - if (use_bf16) b->txt_v_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->txt_v_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->txt_v_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->txt_v_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_add_out.weight", idx); - if (use_bf16) b->txt_proj_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->txt_proj_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->txt_proj_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->txt_proj_weight = mmap_get_f32(sf, name); /* Text FFN - skip f32 if bf16 available */ snprintf(name, sizeof(name), "transformer_blocks.%d.ff_context.linear_in.weight", idx); if (use_bf16) { - uint16_t *txt_ff_in_bf16 = mmap_get_bf16(files, num_files, name); + uint16_t *txt_ff_in_bf16 = mmap_get_bf16(sf, name); if (txt_ff_in_bf16) { /* Direct pointers with offset - no malloc/copy needed */ b->txt_mlp_gate_weight_bf16 = txt_ff_in_bf16; b->txt_mlp_up_weight_bf16 = txt_ff_in_bf16 + (size_t)mlp * h; } } else { - float *txt_ff_in = mmap_get_f32(files, num_files, name); + float *txt_ff_in = mmap_get_f32(sf, name); if (txt_ff_in) { b->txt_mlp_gate_weight = malloc(mlp * h * sizeof(float)); b->txt_mlp_up_weight = malloc(mlp * h * sizeof(float)); @@ -609,8 +417,8 @@ static int load_double_block_weights(double_block_t *b, safetensors_file_t **fil } snprintf(name, sizeof(name), "transformer_blocks.%d.ff_context.linear_out.weight", idx); - if (use_bf16) b->txt_mlp_down_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->txt_mlp_down_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->txt_mlp_down_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->txt_mlp_down_weight = mmap_get_f32(sf, name); return 0; } @@ -618,11 +426,6 @@ static int load_double_block_weights(double_block_t *b, safetensors_file_t **fil /* Free weights for a single double_block (mmap mode only) * Note: bf16 pointers are direct mmap pointers, don't free them */ static void free_double_block_weights(double_block_t *b) { -#ifdef USE_METAL - /* Invalidate GPU weight cache before freeing CPU pointers. - * malloc can reuse freed addresses, causing stale cache hits. */ - flux_metal_clear_weight_cache_only(); -#endif free(b->img_norm_q_weight); b->img_norm_q_weight = NULL; free(b->img_norm_k_weight); b->img_norm_k_weight = NULL; free(b->img_q_weight); b->img_q_weight = NULL; @@ -660,26 +463,26 @@ static void free_double_block_weights(double_block_t *b) { } /* Load weights for a single single_block on-demand */ -static int load_single_block_weights(single_block_t *b, safetensors_file_t **files, - int num_files, int idx, int h, int mlp, int use_bf16) { +static int load_single_block_weights(single_block_t *b, safetensors_file_t *sf, + int idx, int h, int mlp, int use_bf16) { char name[256]; (void)h; (void)mlp; /* Unused in single block */ /* QK norm weights (always f32, small) */ snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.norm_q.weight", idx); - b->norm_q_weight = mmap_get_f32(files, num_files, name); + b->norm_q_weight = mmap_get_f32(sf, name); snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.norm_k.weight", idx); - b->norm_k_weight = mmap_get_f32(files, num_files, name); + b->norm_k_weight = mmap_get_f32(sf, name); /* Fused QKV+MLP input projection - skip f32 if bf16 available */ snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.to_qkv_mlp_proj.weight", idx); - if (use_bf16) b->qkv_mlp_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->qkv_mlp_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->qkv_mlp_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->qkv_mlp_weight = mmap_get_f32(sf, name); /* Fused attn out + MLP down projection - skip f32 if bf16 available */ snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.to_out.weight", idx); - if (use_bf16) b->proj_mlp_weight_bf16 = mmap_get_bf16(files, num_files, name); - if (!use_bf16) b->proj_mlp_weight = mmap_get_f32(files, num_files, name); + if (use_bf16) b->proj_mlp_weight_bf16 = mmap_get_bf16(sf, name); + if (!use_bf16) b->proj_mlp_weight = mmap_get_f32(sf, name); return 0; } @@ -687,9 +490,6 @@ static int load_single_block_weights(single_block_t *b, safetensors_file_t **fil /* Free weights for a single single_block (mmap mode only) * Note: bf16 pointers are direct mmap pointers, don't free them */ static void free_single_block_weights(single_block_t *b) { -#ifdef USE_METAL - flux_metal_clear_weight_cache_only(); -#endif free(b->norm_q_weight); b->norm_q_weight = NULL; free(b->norm_k_weight); b->norm_k_weight = NULL; free(b->qkv_mlp_weight); b->qkv_mlp_weight = NULL; @@ -699,16 +499,6 @@ static void free_single_block_weights(single_block_t *b) { b->proj_mlp_weight_bf16 = NULL; } -/* Free cached mmap weights for all blocks. - * Called after denoising completes to release memory held across steps. */ -void flux_transformer_free_mmap_cache(flux_transformer_t *tf) { - if (!tf || !tf->use_mmap) return; - for (int i = 0; i < tf->num_double_layers; i++) - free_double_block_weights(&tf->double_blocks[i]); - for (int i = 0; i < tf->num_single_layers; i++) - free_single_block_weights(&tf->single_blocks[i]); -} - #ifdef USE_METAL /* Pre-warm bf16 weight buffer cache for all blocks (mmap mode). * Loads each block's bf16 mmap pointers and copies weight data to Metal @@ -732,7 +522,7 @@ static void warmup_mmap_bf16_buffers(flux_transformer_t *tf) { /* Double blocks */ for (int i = 0; i < tf->num_double_layers; i++) { - load_double_block_weights(&tf->double_blocks[i], tf->sf_files, tf->num_sf_files, i, h, mlp, 1); + load_double_block_weights(&tf->double_blocks[i], tf->sf, i, h, mlp, 1); double_block_t *b = &tf->double_blocks[i]; if (b->img_q_weight_bf16) @@ -770,7 +560,7 @@ static void warmup_mmap_bf16_buffers(flux_transformer_t *tf) { /* Single blocks */ for (int i = 0; i < tf->num_single_layers; i++) { - load_single_block_weights(&tf->single_blocks[i], tf->sf_files, tf->num_sf_files, i, h, mlp, 1); + load_single_block_weights(&tf->single_blocks[i], tf->sf, i, h, mlp, 1); single_block_t *b = &tf->single_blocks[i]; if (b->qkv_mlp_weight_bf16) @@ -1332,7 +1122,7 @@ static void apply_qk_norm(float *q, float *k, /* Multi-head self-attention */ -#ifdef USE_METAL +#if defined(USE_METAL) || defined(USE_BLAS) /* Transpose from [seq, heads, head_dim] to [heads, seq, head_dim] * Needed for batched attention that processes each head separately */ static void transpose_shd_to_hsd(float *out, const float *in, @@ -1357,7 +1147,7 @@ static void transpose_hsd_to_shd(float *out, const float *in, } } } -#endif /* USE_METAL */ +#endif /* USE_METAL || USE_BLAS */ /* Ensure attn_scores buffer is large enough for current sequence lengths. * Only needed for BLAS/Metal paths - flash attention doesn't use this buffer. @@ -1464,105 +1254,6 @@ static int ensure_work_buffers(flux_transformer_t *tf, int total_seq) { return 0; } -/* ======================================================================== - * Thread-parallel attention for BLAS path. - * Per-head sgemm is too small for BLAS internal threading, so we - * parallelize across heads using pthreads instead. - * ======================================================================== */ - -#ifdef USE_BLAS -/* Work descriptor for self-attention (single blocks) */ -typedef struct { - const float *q, *k, *v; - float *out, *scores; - int seq, head_dim, hidden; - float scale; - int head_start, head_end; -} mha_thread_work_t; - -static void *mha_thread_worker(void *arg) { - mha_thread_work_t *w = (mha_thread_work_t *)arg; - for (int h = w->head_start; h < w->head_end; h++) { - const float *qh = w->q + h * w->head_dim; - const float *kh = w->k + h * w->head_dim; - const float *vh = w->v + h * w->head_dim; - float *oh = w->out + h * w->head_dim; - float *sh = w->scores + (size_t)h * w->seq * w->seq; - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - w->seq, w->seq, w->head_dim, - w->scale, qh, w->hidden, kh, w->hidden, - 0.0f, sh, w->seq); - flux_softmax_cpu(sh, w->seq, w->seq); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - w->seq, w->head_dim, w->seq, - 1.0f, sh, w->seq, vh, w->hidden, - 0.0f, oh, w->hidden); - } - return NULL; -} - -/* Work descriptor for joint attention (double blocks) */ -typedef struct { - const float *img_q, *txt_q, *cat_k, *cat_v; - float *img_out, *txt_out, *scores; - int img_seq, txt_seq, total_seq, head_dim, hidden; - float scale; - int head_start, head_end; -} joint_attn_thread_work_t; - -static void *joint_attn_thread_worker(void *arg) { - joint_attn_thread_work_t *w = (joint_attn_thread_work_t *)arg; - for (int h = w->head_start; h < w->head_end; h++) { - const float *img_qh = w->img_q + h * w->head_dim; - const float *txt_qh = w->txt_q + h * w->head_dim; - const float *kh = w->cat_k + h * w->head_dim; - const float *vh = w->cat_v + h * w->head_dim; - float *img_oh = w->img_out + h * w->head_dim; - float *txt_oh = w->txt_out + h * w->head_dim; - float *img_sh = w->scores + (size_t)h * w->total_seq * w->total_seq; - float *txt_sh = img_sh + (size_t)w->img_seq * w->total_seq; - - /* Image attention */ - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - w->img_seq, w->total_seq, w->head_dim, - w->scale, img_qh, w->hidden, kh, w->hidden, - 0.0f, img_sh, w->total_seq); - flux_softmax_cpu(img_sh, w->img_seq, w->total_seq); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - w->img_seq, w->head_dim, w->total_seq, - 1.0f, img_sh, w->total_seq, vh, w->hidden, - 0.0f, img_oh, w->hidden); - - /* Text attention */ - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - w->txt_seq, w->total_seq, w->head_dim, - w->scale, txt_qh, w->hidden, kh, w->hidden, - 0.0f, txt_sh, w->total_seq); - flux_softmax_cpu(txt_sh, w->txt_seq, w->total_seq); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - w->txt_seq, w->head_dim, w->total_seq, - 1.0f, txt_sh, w->total_seq, vh, w->hidden, - 0.0f, txt_oh, w->hidden); - } - return NULL; -} - -/* Get number of threads for head-parallel attention. - * Uses CPU core count, capped to divide num_heads evenly. */ -static int get_attn_num_threads(int heads) { - static int cached = 0; - if (cached) return cached; - int ncpu = (int)sysconf(_SC_NPROCESSORS_ONLN); - if (ncpu < 2) { cached = 1; return 1; } - if (ncpu > heads) ncpu = heads; - /* Round down to divide heads evenly */ - while (heads % ncpu != 0) ncpu--; - cached = ncpu; - return cached; -} -#endif /* USE_BLAS */ - /* Multi-head attention with BLAS optimization * Uses pre-allocated workspace buffers from transformer struct */ @@ -1601,58 +1292,78 @@ static void mha_forward(float *out, const float *q, const float *k, const float } #endif +#ifdef USE_CUDA + /* CUDA fast path for [seq, hidden] layout used by single-stream blocks. + * Avoids CPU SHD<->HSD transposes by doing them on-device. */ + if (!getenv("FLUX_CUDA_NO_SHD_ATTN") && + flux_cuda_attention_batched_shd(out, q, k, v, + tf->num_heads, seq, seq, head_dim, + scale, 0, NULL)) { + return; + } +#endif + /* CPU fallback: Use BLAS-optimized attention (faster) or flash attention (memory-efficient) */ #ifdef USE_BLAS - /* BLAS path: thread-parallel per-head attention. - * Q, K, V are [seq, heads*head_dim] layout. We use lda=hidden to stride - * over heads, reading head_dim elements per row directly. - * Per-head sgemm is too small for BLAS internal threading, so we - * parallelize across heads with pthreads for better core utilization. */ + /* BLAS path: transpose + batched matrix multiply per head */ { - int hidden = tf->num_heads * head_dim; + float *q_t = tf->attn_q_t; + float *k_t = tf->attn_k_t; + float *v_t = tf->attn_v_t; + float *out_t = tf->attn_out_t; float *scores = tf->attn_scores; - int nthreads = get_attn_num_threads(tf->num_heads); - int heads_per_thread = tf->num_heads / nthreads; - - if (nthreads <= 1) { - /* Serial fallback */ - for (int h = 0; h < tf->num_heads; h++) { - const float *qh = q + h * head_dim; - const float *kh = k + h * head_dim; - const float *vh = v + h * head_dim; - float *oh = out + h * head_dim; - float *sh = scores + (size_t)h * seq * seq; - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - seq, seq, head_dim, - scale, qh, hidden, kh, hidden, - 0.0f, sh, seq); - flux_softmax(sh, seq, seq); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - seq, head_dim, seq, - 1.0f, sh, seq, vh, hidden, - 0.0f, oh, hidden); - } - } else { - pthread_t threads[nthreads]; - mha_thread_work_t work[nthreads]; - int ok[nthreads]; - for (int t = 0; t < nthreads; t++) { - work[t] = (mha_thread_work_t){ - .q = q, .k = k, .v = v, - .out = out, .scores = scores, - .seq = seq, .head_dim = head_dim, .hidden = hidden, - .scale = scale, - .head_start = t * heads_per_thread, - .head_end = (t + 1) * heads_per_thread, - }; - ok[t] = pthread_create(&threads[t], NULL, mha_thread_worker, &work[t]) == 0; - if (!ok[t]) mha_thread_worker(&work[t]); - } - for (int t = 0; t < nthreads; t++) { - if (ok[t]) pthread_join(threads[t], NULL); + + /* Transpose to [heads, seq, head_dim] for efficient BLAS operations */ + transpose_shd_to_hsd(q_t, q, seq, tf->num_heads, head_dim); + transpose_shd_to_hsd(k_t, k, seq, tf->num_heads, head_dim); + transpose_shd_to_hsd(v_t, v, seq, tf->num_heads, head_dim); + +#ifdef USE_CUDA + if (flux_cuda_attention_batched(out_t, q_t, k_t, v_t, + tf->num_heads, seq, seq, head_dim, + scale, 0, NULL)) { + transpose_hsd_to_shd(out, out_t, seq, tf->num_heads, head_dim); + return; + } +#endif + + /* Process each head with BLAS */ + for (int h = 0; h < tf->num_heads; h++) { + float *qh = q_t + h * seq * head_dim; + float *kh = k_t + h * seq * head_dim; + float *vh = v_t + h * seq * head_dim; + float *oh = out_t + h * seq * head_dim; + float *sh = scores + h * seq * seq; + + /* scores = Q @ K^T using BLAS */ +#ifdef USE_CUDA + flux_matmul_t(sh, qh, kh, seq, head_dim, seq); + for (int i = 0; i < seq * seq; i++) { + sh[i] *= scale; } +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + seq, seq, head_dim, + scale, qh, head_dim, kh, head_dim, + 0.0f, sh, seq); +#endif + + /* Softmax */ + flux_softmax(sh, seq, seq); + + /* out = scores @ V using BLAS */ +#ifdef USE_CUDA + flux_matmul(oh, sh, vh, seq, seq, head_dim); +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, + seq, head_dim, seq, + 1.0f, sh, seq, vh, head_dim, + 0.0f, oh, head_dim); +#endif } + + /* Transpose output back to [seq, heads, head_dim] */ + transpose_hsd_to_shd(out, out_t, seq, tf->num_heads, head_dim); } #else /* Generic fallback: Use flash attention (memory-efficient, no transpose needed) */ @@ -1725,72 +1436,110 @@ static void joint_attention(float *img_out, float *txt_out, } #endif +#ifdef USE_CUDA + /* CUDA fast path for [seq, hidden] layout in double-block joint attention. + * cat_k/cat_v are already built in [total_seq, hidden] sequence-major layout. */ + if (!getenv("FLUX_CUDA_NO_SHD_ATTN") && + flux_cuda_attention_batched_shd(img_out, img_q, cat_k, cat_v, + heads, img_seq, total_seq, head_dim, + scale, 0, NULL) && + flux_cuda_attention_batched_shd(txt_out, txt_q, cat_k, cat_v, + heads, txt_seq, total_seq, head_dim, + scale, 0, NULL)) { + return; + } +#endif + /* CPU fallback: Use BLAS-optimized attention (faster) or flash attention (memory-efficient) */ #ifdef USE_BLAS - /* BLAS path: thread-parallel per-head joint attention. - * All tensors are [seq, heads*head_dim] layout, use lda=hidden for strides. - * Each head gets its own scores slice for thread safety. */ + /* BLAS path: transpose + batched matrix multiply per head */ { + float *img_q_t = tf->attn_q_t; + float *txt_q_t = tf->attn_q_t + img_seq * hidden; + float *cat_k_t = tf->attn_k_t; + float *cat_v_t = tf->attn_v_t; + float *img_out_t = tf->attn_out_t; + float *txt_out_t = tf->attn_out_t + img_seq * hidden; float *scores = tf->attn_scores; - int nthreads = get_attn_num_threads(heads); - int heads_per_thread = heads / nthreads; - - if (nthreads <= 1) { - /* Serial fallback */ - for (int h = 0; h < heads; h++) { - const float *img_qh = img_q + h * head_dim; - const float *txt_qh = txt_q + h * head_dim; - const float *kh = cat_k + h * head_dim; - const float *vh = cat_v + h * head_dim; - float *img_oh = img_out + h * head_dim; - float *txt_oh = txt_out + h * head_dim; - float *img_sh = scores + (size_t)h * total_seq * total_seq; - float *txt_sh = img_sh + (size_t)img_seq * total_seq; - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - img_seq, total_seq, head_dim, - scale, img_qh, hidden, kh, hidden, - 0.0f, img_sh, total_seq); - flux_softmax(img_sh, img_seq, total_seq); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - img_seq, head_dim, total_seq, - 1.0f, img_sh, total_seq, vh, hidden, - 0.0f, img_oh, hidden); - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - txt_seq, total_seq, head_dim, - scale, txt_qh, hidden, kh, hidden, - 0.0f, txt_sh, total_seq); - flux_softmax(txt_sh, txt_seq, total_seq); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - txt_seq, head_dim, total_seq, - 1.0f, txt_sh, total_seq, vh, hidden, - 0.0f, txt_oh, hidden); - } - } else { - pthread_t threads[nthreads]; - joint_attn_thread_work_t work[nthreads]; - int ok[nthreads]; - for (int t = 0; t < nthreads; t++) { - work[t] = (joint_attn_thread_work_t){ - .img_q = img_q, .txt_q = txt_q, - .cat_k = cat_k, .cat_v = cat_v, - .img_out = img_out, .txt_out = txt_out, - .scores = scores, - .img_seq = img_seq, .txt_seq = txt_seq, - .total_seq = total_seq, - .head_dim = head_dim, .hidden = hidden, - .scale = scale, - .head_start = t * heads_per_thread, - .head_end = (t + 1) * heads_per_thread, - }; - ok[t] = pthread_create(&threads[t], NULL, joint_attn_thread_worker, &work[t]) == 0; - if (!ok[t]) joint_attn_thread_worker(&work[t]); + + /* Transpose to [heads, seq, head_dim] for efficient BLAS operations */ + transpose_shd_to_hsd(img_q_t, img_q, img_seq, heads, head_dim); + transpose_shd_to_hsd(txt_q_t, txt_q, txt_seq, heads, head_dim); + transpose_shd_to_hsd(cat_k_t, cat_k, total_seq, heads, head_dim); + transpose_shd_to_hsd(cat_v_t, cat_v, total_seq, heads, head_dim); + +#ifdef USE_CUDA + if (flux_cuda_attention_batched(img_out_t, img_q_t, cat_k_t, cat_v_t, + heads, img_seq, total_seq, head_dim, + scale, 0, NULL) && + flux_cuda_attention_batched(txt_out_t, txt_q_t, cat_k_t, cat_v_t, + heads, txt_seq, total_seq, head_dim, + scale, 0, NULL)) { + transpose_hsd_to_shd(img_out, img_out_t, img_seq, heads, head_dim); + transpose_hsd_to_shd(txt_out, txt_out_t, txt_seq, heads, head_dim); + return; + } +#endif + + /* Process each head with BLAS */ + for (int h = 0; h < heads; h++) { + float *img_qh = img_q_t + h * img_seq * head_dim; + float *txt_qh = txt_q_t + h * txt_seq * head_dim; + float *kh = cat_k_t + h * total_seq * head_dim; + float *vh = cat_v_t + h * total_seq * head_dim; + float *img_oh = img_out_t + h * img_seq * head_dim; + float *txt_oh = txt_out_t + h * txt_seq * head_dim; + float *img_sh = scores; /* Reuse scores buffer */ + float *txt_sh = scores + img_seq * total_seq; + + /* Image attention: img_Q @ cat_K^T */ +#ifdef USE_CUDA + flux_matmul_t(img_sh, img_qh, kh, img_seq, head_dim, total_seq); + for (int i = 0; i < img_seq * total_seq; i++) { + img_sh[i] *= scale; } - for (int t = 0; t < nthreads; t++) { - if (ok[t]) pthread_join(threads[t], NULL); +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + img_seq, total_seq, head_dim, + scale, img_qh, head_dim, kh, head_dim, + 0.0f, img_sh, total_seq); +#endif + flux_softmax(img_sh, img_seq, total_seq); +#ifdef USE_CUDA + flux_matmul(img_oh, img_sh, vh, img_seq, total_seq, head_dim); +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, + img_seq, head_dim, total_seq, + 1.0f, img_sh, total_seq, vh, head_dim, + 0.0f, img_oh, head_dim); +#endif + + /* Text attention: txt_Q @ cat_K^T */ +#ifdef USE_CUDA + flux_matmul_t(txt_sh, txt_qh, kh, txt_seq, head_dim, total_seq); + for (int i = 0; i < txt_seq * total_seq; i++) { + txt_sh[i] *= scale; } +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + txt_seq, total_seq, head_dim, + scale, txt_qh, head_dim, kh, head_dim, + 0.0f, txt_sh, total_seq); +#endif + flux_softmax(txt_sh, txt_seq, total_seq); +#ifdef USE_CUDA + flux_matmul(txt_oh, txt_sh, vh, txt_seq, total_seq, head_dim); +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, + txt_seq, head_dim, total_seq, + 1.0f, txt_sh, total_seq, vh, head_dim, + 0.0f, txt_oh, head_dim); +#endif } + + /* Transpose outputs back */ + transpose_hsd_to_shd(img_out, img_out_t, img_seq, heads, head_dim); + transpose_hsd_to_shd(txt_out, txt_out_t, txt_seq, heads, head_dim); } #else /* Generic fallback: Use flash attention (memory-efficient, no transpose needed) */ @@ -2094,8 +1843,9 @@ static void swiglu_ffn_bf16(float *out, const float *x, LINEAR_BF16_OR_F32(up, x, up_weight, up_weight_bf16, seq, hidden, mlp_hidden); flux_gpu_end_batch(); - /* SiLU(gate) * up - fused for better performance */ - flux_silu_mul(gate, up, seq * mlp_hidden); + /* SiLU(gate) * up */ + flux_silu(gate, seq * mlp_hidden); + flux_mul_inplace(gate, up, seq * mlp_hidden); /* Down projection */ LINEAR_BF16_OR_F32(out, gate, down_weight, down_weight_bf16, seq, mlp_hidden, hidden); @@ -2326,6 +2076,448 @@ static void double_block_forward(float *img_hidden, float *txt_hidden, #endif } +#ifdef USE_CUDA +/* CUDA-resident double-stream path: + * Keeps image/text hidden activations on device across all 5 double blocks. + */ +static int double_blocks_forward_cuda_resident(float *img_hidden, float *txt_hidden, + flux_transformer_t *tf, + const float *img_mod, const float *txt_mod, + const float *img_rope_cos, const float *img_rope_sin, + const float *txt_rope_cos, const float *txt_rope_sin, + int img_seq, int txt_seq) { + if (!img_hidden || !txt_hidden || !tf || !img_mod || !txt_mod || + !img_rope_cos || !img_rope_sin || !txt_rope_cos || !txt_rope_sin) { + return 0; + } + if (tf->use_mmap || tf->use_bf16) { + return 0; + } + + int hidden = tf->hidden_size; + int heads = tf->num_heads; + int head_dim = tf->head_dim; + int mlp_hidden = tf->mlp_hidden; + int total_seq = img_seq + txt_seq; + float eps = 1e-6f; + float attn_scale = 1.0f / sqrtf((float)head_dim); + + const float *img_shift1 = img_mod; + const float *img_scale1 = img_mod + hidden; + const float *img_gate1 = img_mod + hidden * 2; + const float *img_shift2 = img_mod + hidden * 3; + const float *img_scale2 = img_mod + hidden * 4; + const float *img_gate2 = img_mod + hidden * 5; + + const float *txt_shift1 = txt_mod; + const float *txt_scale1 = txt_mod + hidden; + const float *txt_gate1 = txt_mod + hidden * 2; + const float *txt_shift2 = txt_mod + hidden * 3; + const float *txt_scale2 = txt_mod + hidden * 4; + const float *txt_gate2 = txt_mod + hidden * 5; + + typedef struct { + const flux_transformer_t *tf; + int img_seq, txt_seq; + const float *host_img_rope_cos, *host_img_rope_sin; + const float *host_txt_rope_cos, *host_txt_rope_sin; + int streams_ready; + cudaStream_t stream_main, stream_img, stream_txt; + float *d_img_hidden, *d_txt_hidden; + float *d_img_norm, *d_txt_norm; + float *d_img_q, *d_img_k, *d_img_v; + float *d_txt_q, *d_txt_k, *d_txt_v; + float *d_cat_k, *d_cat_v; + float *d_img_attn_out, *d_txt_attn_out; + float *d_img_proj, *d_txt_proj; + float *d_img_gate_ffn, *d_img_up_ffn, *d_img_down; + float *d_txt_gate_ffn, *d_txt_up_ffn, *d_txt_down; + float *d_img_shift1, *d_img_scale1, *d_img_gate1; + float *d_img_shift2, *d_img_scale2, *d_img_gate2; + float *d_txt_shift1, *d_txt_scale1, *d_txt_gate1; + float *d_txt_shift2, *d_txt_scale2, *d_txt_gate2; + float *d_img_rope_cos, *d_img_rope_sin; + float *d_txt_rope_cos, *d_txt_rope_sin; + float *d_img_norm_q_all, *d_img_norm_k_all; + float *d_txt_norm_q_all, *d_txt_norm_k_all; + } cuda_double_ctx_t; + + static cuda_double_ctx_t s = {0}; + int ok = 0; + int parallel = getenv("FLUX_CUDA_DOUBLE_PARALLEL") ? 1 : 0; + +#define CUDA_FREE_PTR(ptr) do { if (ptr) { cudaFree(ptr); ptr = NULL; } } while (0) +#define CUDA_ALLOC_PTR(ptr, elems) \ + do { \ + if (cudaMalloc((void **)&(ptr), (size_t)(elems) * sizeof(float)) != cudaSuccess) goto fail; \ + } while (0) +#define CUDA_H2D(dst, src, elems) \ + do { \ + if (cudaMemcpy((dst), (src), (size_t)(elems) * sizeof(float), cudaMemcpyHostToDevice) != cudaSuccess) goto fail; \ + } while (0) +#define CUDA_D2H(dst, src, elems) \ + do { \ + if (cudaMemcpy((dst), (src), (size_t)(elems) * sizeof(float), cudaMemcpyDeviceToHost) != cudaSuccess) goto fail; \ + } while (0) +#define CUDA_SET_STREAM(stream_) \ + do { \ + if (!flux_cuda_linear_set_stream((void *)(stream_))) goto fail; \ + if (!flux_cuda_ops_set_stream((void *)(stream_))) goto fail; \ + } while (0) + + int need_realloc = (!s.d_img_hidden || + s.tf != tf || + s.img_seq != img_seq || + s.txt_seq != txt_seq); + if (need_realloc) { + CUDA_FREE_PTR(s.d_img_hidden); + CUDA_FREE_PTR(s.d_txt_hidden); + CUDA_FREE_PTR(s.d_img_norm); + CUDA_FREE_PTR(s.d_txt_norm); + CUDA_FREE_PTR(s.d_img_q); + CUDA_FREE_PTR(s.d_img_k); + CUDA_FREE_PTR(s.d_img_v); + CUDA_FREE_PTR(s.d_txt_q); + CUDA_FREE_PTR(s.d_txt_k); + CUDA_FREE_PTR(s.d_txt_v); + CUDA_FREE_PTR(s.d_cat_k); + CUDA_FREE_PTR(s.d_cat_v); + CUDA_FREE_PTR(s.d_img_attn_out); + CUDA_FREE_PTR(s.d_txt_attn_out); + CUDA_FREE_PTR(s.d_img_proj); + CUDA_FREE_PTR(s.d_txt_proj); + CUDA_FREE_PTR(s.d_img_gate_ffn); + CUDA_FREE_PTR(s.d_img_up_ffn); + CUDA_FREE_PTR(s.d_img_down); + CUDA_FREE_PTR(s.d_txt_gate_ffn); + CUDA_FREE_PTR(s.d_txt_up_ffn); + CUDA_FREE_PTR(s.d_txt_down); + CUDA_FREE_PTR(s.d_img_shift1); + CUDA_FREE_PTR(s.d_img_scale1); + CUDA_FREE_PTR(s.d_img_gate1); + CUDA_FREE_PTR(s.d_img_shift2); + CUDA_FREE_PTR(s.d_img_scale2); + CUDA_FREE_PTR(s.d_img_gate2); + CUDA_FREE_PTR(s.d_txt_shift1); + CUDA_FREE_PTR(s.d_txt_scale1); + CUDA_FREE_PTR(s.d_txt_gate1); + CUDA_FREE_PTR(s.d_txt_shift2); + CUDA_FREE_PTR(s.d_txt_scale2); + CUDA_FREE_PTR(s.d_txt_gate2); + CUDA_FREE_PTR(s.d_img_rope_cos); + CUDA_FREE_PTR(s.d_img_rope_sin); + CUDA_FREE_PTR(s.d_txt_rope_cos); + CUDA_FREE_PTR(s.d_txt_rope_sin); + CUDA_FREE_PTR(s.d_img_norm_q_all); + CUDA_FREE_PTR(s.d_img_norm_k_all); + CUDA_FREE_PTR(s.d_txt_norm_q_all); + CUDA_FREE_PTR(s.d_txt_norm_k_all); + + CUDA_ALLOC_PTR(s.d_img_hidden, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_hidden, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_norm, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_norm, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_q, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_k, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_v, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_q, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_k, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_v, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_cat_k, (size_t)total_seq * hidden); + CUDA_ALLOC_PTR(s.d_cat_v, (size_t)total_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_attn_out, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_attn_out, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_proj, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_proj, (size_t)txt_seq * hidden); + CUDA_ALLOC_PTR(s.d_img_gate_ffn, (size_t)img_seq * mlp_hidden); + CUDA_ALLOC_PTR(s.d_img_up_ffn, (size_t)img_seq * mlp_hidden); + CUDA_ALLOC_PTR(s.d_img_down, (size_t)img_seq * hidden); + CUDA_ALLOC_PTR(s.d_txt_gate_ffn, (size_t)txt_seq * mlp_hidden); + CUDA_ALLOC_PTR(s.d_txt_up_ffn, (size_t)txt_seq * mlp_hidden); + CUDA_ALLOC_PTR(s.d_txt_down, (size_t)txt_seq * hidden); + + CUDA_ALLOC_PTR(s.d_img_shift1, hidden); + CUDA_ALLOC_PTR(s.d_img_scale1, hidden); + CUDA_ALLOC_PTR(s.d_img_gate1, hidden); + CUDA_ALLOC_PTR(s.d_img_shift2, hidden); + CUDA_ALLOC_PTR(s.d_img_scale2, hidden); + CUDA_ALLOC_PTR(s.d_img_gate2, hidden); + CUDA_ALLOC_PTR(s.d_txt_shift1, hidden); + CUDA_ALLOC_PTR(s.d_txt_scale1, hidden); + CUDA_ALLOC_PTR(s.d_txt_gate1, hidden); + CUDA_ALLOC_PTR(s.d_txt_shift2, hidden); + CUDA_ALLOC_PTR(s.d_txt_scale2, hidden); + CUDA_ALLOC_PTR(s.d_txt_gate2, hidden); + + CUDA_ALLOC_PTR(s.d_img_rope_cos, (size_t)img_seq * head_dim); + CUDA_ALLOC_PTR(s.d_img_rope_sin, (size_t)img_seq * head_dim); + CUDA_ALLOC_PTR(s.d_txt_rope_cos, (size_t)txt_seq * head_dim); + CUDA_ALLOC_PTR(s.d_txt_rope_sin, (size_t)txt_seq * head_dim); + CUDA_ALLOC_PTR(s.d_img_norm_q_all, (size_t)tf->num_double_layers * head_dim); + CUDA_ALLOC_PTR(s.d_img_norm_k_all, (size_t)tf->num_double_layers * head_dim); + CUDA_ALLOC_PTR(s.d_txt_norm_q_all, (size_t)tf->num_double_layers * head_dim); + CUDA_ALLOC_PTR(s.d_txt_norm_k_all, (size_t)tf->num_double_layers * head_dim); + + for (int i = 0; i < tf->num_double_layers; i++) { + const double_block_t *b = &tf->double_blocks[i]; + if (!b->img_norm_q_weight || !b->img_norm_k_weight || + !b->txt_norm_q_weight || !b->txt_norm_k_weight) { + goto fail; + } + CUDA_H2D(s.d_img_norm_q_all + (size_t)i * head_dim, b->img_norm_q_weight, head_dim); + CUDA_H2D(s.d_img_norm_k_all + (size_t)i * head_dim, b->img_norm_k_weight, head_dim); + CUDA_H2D(s.d_txt_norm_q_all + (size_t)i * head_dim, b->txt_norm_q_weight, head_dim); + CUDA_H2D(s.d_txt_norm_k_all + (size_t)i * head_dim, b->txt_norm_k_weight, head_dim); + } + + if (!s.stream_main) { + if (cudaStreamCreateWithFlags(&s.stream_main, cudaStreamNonBlocking) != cudaSuccess) goto fail; + } + if (parallel && !s.streams_ready) { + if (cudaStreamCreateWithFlags(&s.stream_img, cudaStreamNonBlocking) != cudaSuccess) goto fail; + if (cudaStreamCreateWithFlags(&s.stream_txt, cudaStreamNonBlocking) != cudaSuccess) goto fail; + s.streams_ready = 1; + } + + s.tf = tf; + s.img_seq = img_seq; + s.txt_seq = txt_seq; + s.host_img_rope_cos = NULL; + s.host_img_rope_sin = NULL; + s.host_txt_rope_cos = NULL; + s.host_txt_rope_sin = NULL; + } + + CUDA_H2D(s.d_img_hidden, img_hidden, (size_t)img_seq * hidden); + CUDA_H2D(s.d_txt_hidden, txt_hidden, (size_t)txt_seq * hidden); + CUDA_H2D(s.d_img_shift1, img_shift1, hidden); + CUDA_H2D(s.d_img_scale1, img_scale1, hidden); + CUDA_H2D(s.d_img_gate1, img_gate1, hidden); + CUDA_H2D(s.d_img_shift2, img_shift2, hidden); + CUDA_H2D(s.d_img_scale2, img_scale2, hidden); + CUDA_H2D(s.d_img_gate2, img_gate2, hidden); + CUDA_H2D(s.d_txt_shift1, txt_shift1, hidden); + CUDA_H2D(s.d_txt_scale1, txt_scale1, hidden); + CUDA_H2D(s.d_txt_gate1, txt_gate1, hidden); + CUDA_H2D(s.d_txt_shift2, txt_shift2, hidden); + CUDA_H2D(s.d_txt_scale2, txt_scale2, hidden); + CUDA_H2D(s.d_txt_gate2, txt_gate2, hidden); + + if (s.host_img_rope_cos != img_rope_cos || s.host_img_rope_sin != img_rope_sin) { + CUDA_H2D(s.d_img_rope_cos, img_rope_cos, (size_t)img_seq * head_dim); + CUDA_H2D(s.d_img_rope_sin, img_rope_sin, (size_t)img_seq * head_dim); + s.host_img_rope_cos = img_rope_cos; + s.host_img_rope_sin = img_rope_sin; + } + if (s.host_txt_rope_cos != txt_rope_cos || s.host_txt_rope_sin != txt_rope_sin) { + CUDA_H2D(s.d_txt_rope_cos, txt_rope_cos, (size_t)txt_seq * head_dim); + CUDA_H2D(s.d_txt_rope_sin, txt_rope_sin, (size_t)txt_seq * head_dim); + s.host_txt_rope_cos = txt_rope_cos; + s.host_txt_rope_sin = txt_rope_sin; + } + + CUDA_SET_STREAM(s.stream_main); + + for (int i = 0; i < tf->num_double_layers; i++) { + const double_block_t *b = &tf->double_blocks[i]; + const float *d_img_norm_q = s.d_img_norm_q_all + (size_t)i * head_dim; + const float *d_img_norm_k = s.d_img_norm_k_all + (size_t)i * head_dim; + const float *d_txt_norm_q = s.d_txt_norm_q_all + (size_t)i * head_dim; + const float *d_txt_norm_k = s.d_txt_norm_k_all + (size_t)i * head_dim; + + if (!b->img_q_weight || !b->img_k_weight || !b->img_v_weight || + !b->txt_q_weight || !b->txt_k_weight || !b->txt_v_weight || + !b->img_proj_weight || !b->txt_proj_weight || + !b->img_mlp_gate_weight || !b->img_mlp_up_weight || !b->img_mlp_down_weight || + !b->txt_mlp_gate_weight || !b->txt_mlp_up_weight || !b->txt_mlp_down_weight) { + goto fail; + } + + if (parallel && s.streams_ready) { + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_adaln_norm_device(s.d_img_norm, s.d_img_hidden, s.d_img_shift1, s.d_img_scale1, + img_seq, hidden, eps)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_adaln_norm_device(s.d_txt_norm, s.d_txt_hidden, s.d_txt_shift1, s.d_txt_scale1, + txt_seq, hidden, eps)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_linear_nobias_device(s.d_img_q, s.d_img_norm, b->img_q_weight, img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_img_k, s.d_img_norm, b->img_k_weight, img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_img_v, s.d_img_norm, b->img_v_weight, img_seq, hidden, hidden)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_linear_nobias_device(s.d_txt_q, s.d_txt_norm, b->txt_q_weight, txt_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_k, s.d_txt_norm, b->txt_k_weight, txt_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_v, s.d_txt_norm, b->txt_v_weight, txt_seq, hidden, hidden)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_qk_rms_norm_device(s.d_img_q, s.d_img_k, d_img_norm_q, d_img_norm_k, + img_seq, heads, head_dim, eps)) goto fail; + if (!flux_cuda_rope_unified_device(s.d_img_q, s.d_img_k, + s.d_img_rope_cos, s.d_img_rope_sin, + s.d_img_rope_cos, s.d_img_rope_sin, + img_seq, 0, heads, head_dim)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_qk_rms_norm_device(s.d_txt_q, s.d_txt_k, d_txt_norm_q, d_txt_norm_k, + txt_seq, heads, head_dim, eps)) goto fail; + if (!flux_cuda_rope_unified_device(s.d_txt_q, s.d_txt_k, + s.d_txt_rope_cos, s.d_txt_rope_sin, + s.d_txt_rope_cos, s.d_txt_rope_sin, + txt_seq, txt_seq, heads, head_dim)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + } else { + CUDA_SET_STREAM(s.stream_main); + if (!flux_cuda_adaln_norm_device(s.d_img_norm, s.d_img_hidden, s.d_img_shift1, s.d_img_scale1, + img_seq, hidden, eps)) goto fail; + if (!flux_cuda_adaln_norm_device(s.d_txt_norm, s.d_txt_hidden, s.d_txt_shift1, s.d_txt_scale1, + txt_seq, hidden, eps)) goto fail; + + if (!flux_cuda_linear_nobias_device(s.d_img_q, s.d_img_norm, b->img_q_weight, img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_img_k, s.d_img_norm, b->img_k_weight, img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_img_v, s.d_img_norm, b->img_v_weight, img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_q, s.d_txt_norm, b->txt_q_weight, txt_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_k, s.d_txt_norm, b->txt_k_weight, txt_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_v, s.d_txt_norm, b->txt_v_weight, txt_seq, hidden, hidden)) goto fail; + + if (!flux_cuda_qk_rms_norm_device(s.d_img_q, s.d_img_k, d_img_norm_q, d_img_norm_k, + img_seq, heads, head_dim, eps)) goto fail; + if (!flux_cuda_qk_rms_norm_device(s.d_txt_q, s.d_txt_k, d_txt_norm_q, d_txt_norm_k, + txt_seq, heads, head_dim, eps)) goto fail; + if (!flux_cuda_rope_unified_device(s.d_img_q, s.d_img_k, + s.d_img_rope_cos, s.d_img_rope_sin, + s.d_img_rope_cos, s.d_img_rope_sin, + img_seq, 0, heads, head_dim)) goto fail; + if (!flux_cuda_rope_unified_device(s.d_txt_q, s.d_txt_k, + s.d_txt_rope_cos, s.d_txt_rope_sin, + s.d_txt_rope_cos, s.d_txt_rope_sin, + txt_seq, txt_seq, heads, head_dim)) goto fail; + } + + CUDA_SET_STREAM(s.stream_main); + if (!flux_cuda_concat_seq_device(s.d_cat_k, s.d_txt_k, s.d_img_k, txt_seq, img_seq, hidden)) goto fail; + if (!flux_cuda_concat_seq_device(s.d_cat_v, s.d_txt_v, s.d_img_v, txt_seq, img_seq, hidden)) goto fail; + + if (!flux_cuda_attention_batched_shd_device(s.d_img_attn_out, s.d_img_q, s.d_cat_k, s.d_cat_v, + heads, img_seq, total_seq, head_dim, + attn_scale, 0, NULL)) goto fail; + if (!flux_cuda_attention_batched_shd_device(s.d_txt_attn_out, s.d_txt_q, s.d_cat_k, s.d_cat_v, + heads, txt_seq, total_seq, head_dim, + attn_scale, 0, NULL)) goto fail; + + if (parallel && s.streams_ready) { + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_linear_nobias_device(s.d_img_proj, s.d_img_attn_out, b->img_proj_weight, + img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_img_hidden, s.d_img_gate1, s.d_img_proj, img_seq, hidden)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_linear_nobias_device(s.d_txt_proj, s.d_txt_attn_out, b->txt_proj_weight, + txt_seq, hidden, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_txt_hidden, s.d_txt_gate1, s.d_txt_proj, txt_seq, hidden)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_adaln_norm_device(s.d_img_norm, s.d_img_hidden, s.d_img_shift2, s.d_img_scale2, + img_seq, hidden, eps)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_adaln_norm_device(s.d_txt_norm, s.d_txt_hidden, s.d_txt_shift2, s.d_txt_scale2, + txt_seq, hidden, eps)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_linear_nobias_device(s.d_img_gate_ffn, s.d_img_norm, b->img_mlp_gate_weight, + img_seq, hidden, mlp_hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_img_up_ffn, s.d_img_norm, b->img_mlp_up_weight, + img_seq, hidden, mlp_hidden)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_linear_nobias_device(s.d_txt_gate_ffn, s.d_txt_norm, b->txt_mlp_gate_weight, + txt_seq, hidden, mlp_hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_up_ffn, s.d_txt_norm, b->txt_mlp_up_weight, + txt_seq, hidden, mlp_hidden)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_silu_mul_device(s.d_img_gate_ffn, s.d_img_up_ffn, img_seq * mlp_hidden)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_silu_mul_device(s.d_txt_gate_ffn, s.d_txt_up_ffn, txt_seq * mlp_hidden)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + + CUDA_SET_STREAM(s.stream_img); + if (!flux_cuda_linear_nobias_device(s.d_img_down, s.d_img_gate_ffn, b->img_mlp_down_weight, + img_seq, mlp_hidden, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_img_hidden, s.d_img_gate2, s.d_img_down, img_seq, hidden)) goto fail; + CUDA_SET_STREAM(s.stream_txt); + if (!flux_cuda_linear_nobias_device(s.d_txt_down, s.d_txt_gate_ffn, b->txt_mlp_down_weight, + txt_seq, mlp_hidden, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_txt_hidden, s.d_txt_gate2, s.d_txt_down, txt_seq, hidden)) goto fail; + if (cudaStreamSynchronize(s.stream_img) != cudaSuccess) goto fail; + if (cudaStreamSynchronize(s.stream_txt) != cudaSuccess) goto fail; + } else { + CUDA_SET_STREAM(s.stream_main); + if (!flux_cuda_linear_nobias_device(s.d_img_proj, s.d_img_attn_out, b->img_proj_weight, + img_seq, hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_proj, s.d_txt_attn_out, b->txt_proj_weight, + txt_seq, hidden, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_img_hidden, s.d_img_gate1, s.d_img_proj, img_seq, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_txt_hidden, s.d_txt_gate1, s.d_txt_proj, txt_seq, hidden)) goto fail; + + if (!flux_cuda_adaln_norm_device(s.d_img_norm, s.d_img_hidden, s.d_img_shift2, s.d_img_scale2, + img_seq, hidden, eps)) goto fail; + if (!flux_cuda_adaln_norm_device(s.d_txt_norm, s.d_txt_hidden, s.d_txt_shift2, s.d_txt_scale2, + txt_seq, hidden, eps)) goto fail; + + if (!flux_cuda_linear_nobias_device(s.d_img_gate_ffn, s.d_img_norm, b->img_mlp_gate_weight, + img_seq, hidden, mlp_hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_img_up_ffn, s.d_img_norm, b->img_mlp_up_weight, + img_seq, hidden, mlp_hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_gate_ffn, s.d_txt_norm, b->txt_mlp_gate_weight, + txt_seq, hidden, mlp_hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_up_ffn, s.d_txt_norm, b->txt_mlp_up_weight, + txt_seq, hidden, mlp_hidden)) goto fail; + + if (!flux_cuda_silu_mul_device(s.d_img_gate_ffn, s.d_img_up_ffn, img_seq * mlp_hidden)) goto fail; + if (!flux_cuda_silu_mul_device(s.d_txt_gate_ffn, s.d_txt_up_ffn, txt_seq * mlp_hidden)) goto fail; + + if (!flux_cuda_linear_nobias_device(s.d_img_down, s.d_img_gate_ffn, b->img_mlp_down_weight, + img_seq, mlp_hidden, hidden)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_txt_down, s.d_txt_gate_ffn, b->txt_mlp_down_weight, + txt_seq, mlp_hidden, hidden)) goto fail; + + if (!flux_cuda_gated_add_device(s.d_img_hidden, s.d_img_gate2, s.d_img_down, img_seq, hidden)) goto fail; + if (!flux_cuda_gated_add_device(s.d_txt_hidden, s.d_txt_gate2, s.d_txt_down, txt_seq, hidden)) goto fail; + } + + if (flux_substep_callback) + flux_substep_callback(FLUX_SUBSTEP_DOUBLE_BLOCK, i, tf->num_double_layers); + } + + CUDA_SET_STREAM(s.stream_main); + if (cudaStreamSynchronize(s.stream_main) != cudaSuccess) goto fail; + CUDA_D2H(img_hidden, s.d_img_hidden, (size_t)img_seq * hidden); + CUDA_D2H(txt_hidden, s.d_txt_hidden, (size_t)txt_seq * hidden); + ok = 1; + +fail: + flux_cuda_linear_set_stream(NULL); + flux_cuda_ops_set_stream(NULL); + return ok; + +#undef CUDA_FREE_PTR +#undef CUDA_ALLOC_PTR +#undef CUDA_H2D +#undef CUDA_D2H +#undef CUDA_SET_STREAM +} +#endif /* USE_CUDA */ + /* ======================================================================== * Single-Stream Block (Parallel DiT) * ======================================================================== */ @@ -3176,7 +3368,7 @@ static float *flux_transformer_forward_bf16(flux_transformer_t *tf, /* Double-stream blocks */ for (int i = 0; i < tf->num_double_layers; i++) { if (tf->use_mmap) { - load_double_block_weights(&tf->double_blocks[i], tf->sf_files, tf->num_sf_files, i, + load_double_block_weights(&tf->double_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } if (!double_block_forward_bf16(img_hidden, txt_hidden, @@ -3205,7 +3397,7 @@ static float *flux_transformer_forward_bf16(flux_transformer_t *tf, /* Single-stream blocks */ for (int i = 0; i < tf->num_single_layers; i++) { if (tf->use_mmap) { - load_single_block_weights(&tf->single_blocks[i], tf->sf_files, tf->num_sf_files, i, + load_single_block_weights(&tf->single_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } @@ -3330,11 +3522,11 @@ static float *flux_transformer_forward_bf16(flux_transformer_t *tf, } #endif /* USE_METAL */ -static void single_block_forward(float *hidden, const single_block_t *block, - const float *t_emb, const float *adaln_weight, - const float *img_rope_cos, const float *img_rope_sin, - const float *txt_rope_cos, const float *txt_rope_sin, - int seq, int img_offset, flux_transformer_t *tf) { +static void single_block_forward_precomputed(float *hidden, const single_block_t *block, + const float *shift, const float *scale, const float *gate, + const float *img_rope_cos, const float *img_rope_sin, + const float *txt_rope_cos, const float *txt_rope_sin, + int seq, int img_offset, flux_transformer_t *tf) { /* seq = total_seq (txt + img) * img_offset = txt_seq (where image starts in the [txt, img] concatenation) */ @@ -3346,34 +3538,9 @@ static void single_block_forward(float *hidden, const single_block_t *block, int img_seq = seq - img_offset; /* Number of image tokens */ float eps = 1e-6f; - /* Compute AdaLN parameters (3: shift, scale, gate) - * adaln_weight is [hidden*3, hidden], t_emb is [hidden] - * FLUX applies SiLU to t_emb before the modulation projection - */ - int mod_size = h_size * 3; - double _t0 = prof_get_time(); - - /* Apply SiLU to t_emb for modulation - use pre-allocated buffer */ - float *t_emb_silu = tf->t_emb_silu; - for (int i = 0; i < h_size; i++) { - float x = t_emb[i]; - t_emb_silu[i] = x / (1.0f + expf(-x)); - } - - /* Use end of work2 for mod_params (3*hidden = 9216 floats) - * fused_out uses seq * fused_dim floats, place mod_params after */ - float *mod_params = tf->work2 + seq * fused_dim; - flux_linear_nobias(mod_params, t_emb_silu, adaln_weight, 1, h_size, mod_size); - - float *shift = mod_params; - float *scale = mod_params + h_size; - float *gate = mod_params + h_size * 2; - /* Norm */ float *norm = tf->work1; apply_adaln(norm, hidden, shift, scale, seq, h_size, eps); - double _t1 = prof_get_time(); - prof_single_adaln += _t1 - _t0; /* Fused QKV + FFN input projection * Output: [seq, fused_dim] where fused_dim = [Q, K, V, gate, up] @@ -3382,8 +3549,6 @@ static void single_block_forward(float *hidden, const single_block_t *block, float *fused_out = tf->work2; LINEAR_BF16_OR_F32(fused_out, norm, block->qkv_mlp_weight, block->qkv_mlp_weight_bf16, seq, h_size, fused_dim); - double _t2 = prof_get_time(); - prof_single_fused_matmul += _t2 - _t1; /* Split outputs: use pre-allocated buffers * Each position has [Q, K, V, gate, up] concatenated @@ -3403,9 +3568,6 @@ static void single_block_forward(float *hidden, const single_block_t *block, memcpy(mlp_up + s * mlp_hidden, row + h_size * 3 + mlp_hidden, mlp_hidden * sizeof(float)); } - double _t3 = prof_get_time(); - prof_single_split += _t3 - _t2; - /* Apply QK normalization */ apply_qk_norm(q, k, block->norm_q_weight, block->norm_k_weight, seq, heads, head_dim, eps); @@ -3427,20 +3589,13 @@ static void single_block_forward(float *hidden, const single_block_t *block, apply_rope_2d(img_q, img_rope_cos, img_rope_sin, img_seq, heads, head_dim, axis_dim); apply_rope_2d(img_k, img_rope_cos, img_rope_sin, img_seq, heads, head_dim, axis_dim); - double _t4 = prof_get_time(); - prof_single_qknorm_rope += _t4 - _t3; - /* Self-attention - use pre-allocated buffer */ float *attn_out = tf->single_attn_out; mha_forward(attn_out, q, k, v, seq, heads, head_dim, tf); - double _t5 = prof_get_time(); - prof_single_attention += _t5 - _t4; - - /* SwiGLU: silu(gate) * up - fused for better performance */ - flux_silu_mul(mlp_gate, mlp_up, seq * mlp_hidden); - double _t6 = prof_get_time(); - prof_single_swiglu += _t6 - _t5; + /* SwiGLU: silu(gate) * up */ + flux_silu(mlp_gate, seq * mlp_hidden); + flux_mul_inplace(mlp_gate, mlp_up, seq * mlp_hidden); /* Fused output projection: [attn_out, mlp_out] -> hidden * proj_mlp_weight: [hidden, hidden + mlp_hidden] @@ -3458,17 +3613,368 @@ static void single_block_forward(float *hidden, const single_block_t *block, LINEAR_BF16_OR_F32(proj_out, concat, block->proj_mlp_weight, block->proj_mlp_weight_bf16, seq, h_size + mlp_hidden, h_size); - double _t7 = prof_get_time(); - prof_single_proj_matmul += _t7 - _t6; - /* Apply gate and add residual - use vectorized helper */ gated_add(hidden, gate, proj_out, seq, h_size); - double _t8 = prof_get_time(); - prof_single_gated_add += _t8 - _t7; /* No free - using pre-allocated buffers */ } +#ifdef USE_CUDA +/* CUDA-resident single-stream path: + * Keeps hidden activations on GPU across all single blocks and executes + * block internals with CUDA kernels + cuBLAS device GEMM. + * Optionally runs the final AdaLN + projection on GPU and returns output in + * [img_seq, latent_channels] (NLC) host layout. + */ +static int single_blocks_forward_cuda_resident(float *hidden, + flux_transformer_t *tf, + const float *shift, const float *scale, const float *gate, + const float *img_rope_cos, const float *img_rope_sin, + const float *txt_rope_cos, const float *txt_rope_sin, + int seq, int img_offset, + const float *final_shift, const float *final_scale, + const float *final_proj_weight, int latent_channels, + float *final_out_nlc) { + if (!hidden || !tf || !shift || !scale || !gate || + !img_rope_cos || !img_rope_sin || !txt_rope_cos || !txt_rope_sin) { + return 0; + } + if (tf->use_mmap || tf->use_bf16) { + return 0; + } + if (final_out_nlc && + (!final_shift || !final_scale || !final_proj_weight || latent_channels <= 0)) { + return 0; + } + + int hidden_size = tf->hidden_size; + int heads = tf->num_heads; + int head_dim = tf->head_dim; + int mlp_hidden = tf->mlp_hidden; + int fused_dim = hidden_size * 3 + mlp_hidden * 2; + int img_seq = seq - img_offset; + float eps = 1e-6f; + float attn_scale = 1.0f / sqrtf((float)head_dim); + + if (seq <= 0 || img_seq <= 0) return 0; + + typedef struct { + const flux_transformer_t *tf; + int seq, img_offset, img_seq; + const float *host_txt_cos, *host_txt_sin; + const float *host_img_cos, *host_img_sin; + cudaStream_t stream_main; + float *d_hidden, *d_norm, *d_fused; + float *d_q, *d_k, *d_v; + float *d_gate_mlp, *d_up, *d_attn_out; + float *d_concat, *d_proj; + float *d_shift, *d_scale, *d_gate; + float *d_txt_cos, *d_txt_sin, *d_img_cos, *d_img_sin; + float *d_norm_q_all, *d_norm_k_all; + float *d_final_shift, *d_final_scale; + float *d_final_norm, *d_final_out; + cudaGraph_t graph; + cudaGraphExec_t graph_exec; + int graph_valid; + int graph_has_final; + int graph_seq; + int graph_img_offset; + int graph_latent_channels; + } cuda_single_ctx_t; + + static cuda_single_ctx_t s = {0}; + int ok = 0; + int use_graph = getenv("FLUX_CUDA_GRAPH_SINGLE") ? 1 : 0; + int has_final = final_out_nlc ? 1 : 0; + +#define CUDA_FREE_PTR(ptr) do { if (ptr) { cudaFree(ptr); ptr = NULL; } } while (0) +#define CUDA_ALLOC_PTR(ptr, elems) \ + do { \ + if (cudaMalloc((void **)&(ptr), (size_t)(elems) * sizeof(float)) != cudaSuccess) goto fail; \ + } while (0) +#define CUDA_H2D(dst, src, elems) \ + do { \ + if (cudaMemcpy((dst), (src), (size_t)(elems) * sizeof(float), cudaMemcpyHostToDevice) != cudaSuccess) goto fail; \ + } while (0) +#define CUDA_D2H(dst, src, elems) \ + do { \ + if (cudaMemcpy((dst), (src), (size_t)(elems) * sizeof(float), cudaMemcpyDeviceToHost) != cudaSuccess) goto fail; \ + } while (0) +#define CUDA_SET_STREAM(stream_) \ + do { \ + if (!flux_cuda_linear_set_stream((void *)(stream_))) goto fail; \ + if (!flux_cuda_ops_set_stream((void *)(stream_))) goto fail; \ + } while (0) + +#define RUN_SINGLE_LOOP(EMIT_CALLBACKS) \ + do { \ + for (int i = 0; i < tf->num_single_layers; i++) { \ + const single_block_t *b = &tf->single_blocks[i]; \ + const float *d_norm_q = s.d_norm_q_all + (size_t)i * head_dim; \ + const float *d_norm_k = s.d_norm_k_all + (size_t)i * head_dim; \ + if (!b->qkv_mlp_weight || !b->proj_mlp_weight) goto fail; \ + if (!flux_cuda_adaln_norm_device(s.d_norm, s.d_hidden, s.d_shift, s.d_scale, seq, hidden_size, eps)) goto fail; \ + if (!flux_cuda_linear_nobias_device(s.d_fused, s.d_norm, b->qkv_mlp_weight, \ + seq, hidden_size, fused_dim)) goto fail; \ + if (!flux_cuda_split_qkv_mlp_device(s.d_fused, s.d_q, s.d_k, s.d_v, s.d_gate_mlp, s.d_up, \ + seq, hidden_size, mlp_hidden)) goto fail; \ + if (!flux_cuda_qk_rms_norm_device(s.d_q, s.d_k, d_norm_q, d_norm_k, \ + seq, heads, head_dim, eps)) goto fail; \ + if (!flux_cuda_rope_unified_device(s.d_q, s.d_k, s.d_txt_cos, s.d_txt_sin, s.d_img_cos, s.d_img_sin, \ + seq, img_offset, heads, head_dim)) goto fail; \ + if (!flux_cuda_attention_batched_shd_device(s.d_attn_out, s.d_q, s.d_k, s.d_v, \ + heads, seq, seq, head_dim, \ + attn_scale, 0, NULL)) goto fail; \ + if (!flux_cuda_silu_mul_device(s.d_gate_mlp, s.d_up, seq * mlp_hidden)) goto fail; \ + if (!flux_cuda_concat_attn_mlp_device(s.d_attn_out, s.d_gate_mlp, s.d_concat, \ + seq, hidden_size, mlp_hidden)) goto fail; \ + if (!flux_cuda_linear_nobias_device(s.d_proj, s.d_concat, b->proj_mlp_weight, \ + seq, hidden_size + mlp_hidden, hidden_size)) goto fail; \ + if (!flux_cuda_gated_add_device(s.d_hidden, s.d_gate, s.d_proj, seq, hidden_size)) goto fail; \ + if ((EMIT_CALLBACKS) && flux_substep_callback) { \ + flux_substep_callback(FLUX_SUBSTEP_SINGLE_BLOCK, i, tf->num_single_layers); \ + } \ + } \ + } while (0) + + int need_realloc = (!s.d_hidden || + s.tf != tf || + s.seq != seq || + s.img_offset != img_offset); + if (need_realloc) { + if (s.graph_exec) { cudaGraphExecDestroy(s.graph_exec); s.graph_exec = NULL; } + if (s.graph) { cudaGraphDestroy(s.graph); s.graph = NULL; } + s.graph_valid = 0; + + CUDA_FREE_PTR(s.d_hidden); + CUDA_FREE_PTR(s.d_norm); + CUDA_FREE_PTR(s.d_fused); + CUDA_FREE_PTR(s.d_q); + CUDA_FREE_PTR(s.d_k); + CUDA_FREE_PTR(s.d_v); + CUDA_FREE_PTR(s.d_gate_mlp); + CUDA_FREE_PTR(s.d_up); + CUDA_FREE_PTR(s.d_attn_out); + CUDA_FREE_PTR(s.d_concat); + CUDA_FREE_PTR(s.d_proj); + CUDA_FREE_PTR(s.d_shift); + CUDA_FREE_PTR(s.d_scale); + CUDA_FREE_PTR(s.d_gate); + CUDA_FREE_PTR(s.d_txt_cos); + CUDA_FREE_PTR(s.d_txt_sin); + CUDA_FREE_PTR(s.d_img_cos); + CUDA_FREE_PTR(s.d_img_sin); + CUDA_FREE_PTR(s.d_norm_q_all); + CUDA_FREE_PTR(s.d_norm_k_all); + CUDA_FREE_PTR(s.d_final_shift); + CUDA_FREE_PTR(s.d_final_scale); + CUDA_FREE_PTR(s.d_final_norm); + CUDA_FREE_PTR(s.d_final_out); + + CUDA_ALLOC_PTR(s.d_hidden, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_norm, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_fused, (size_t)seq * fused_dim); + CUDA_ALLOC_PTR(s.d_q, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_k, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_v, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_gate_mlp, (size_t)seq * mlp_hidden); + CUDA_ALLOC_PTR(s.d_up, (size_t)seq * mlp_hidden); + CUDA_ALLOC_PTR(s.d_attn_out, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_concat, (size_t)seq * (hidden_size + mlp_hidden)); + CUDA_ALLOC_PTR(s.d_proj, (size_t)seq * hidden_size); + CUDA_ALLOC_PTR(s.d_shift, hidden_size); + CUDA_ALLOC_PTR(s.d_scale, hidden_size); + CUDA_ALLOC_PTR(s.d_gate, hidden_size); + CUDA_ALLOC_PTR(s.d_txt_cos, (size_t)img_offset * head_dim); + CUDA_ALLOC_PTR(s.d_txt_sin, (size_t)img_offset * head_dim); + CUDA_ALLOC_PTR(s.d_img_cos, (size_t)img_seq * head_dim); + CUDA_ALLOC_PTR(s.d_img_sin, (size_t)img_seq * head_dim); + CUDA_ALLOC_PTR(s.d_norm_q_all, (size_t)tf->num_single_layers * head_dim); + CUDA_ALLOC_PTR(s.d_norm_k_all, (size_t)tf->num_single_layers * head_dim); + CUDA_ALLOC_PTR(s.d_final_shift, hidden_size); + CUDA_ALLOC_PTR(s.d_final_scale, hidden_size); + CUDA_ALLOC_PTR(s.d_final_norm, (size_t)img_seq * hidden_size); + CUDA_ALLOC_PTR(s.d_final_out, (size_t)img_seq * tf->latent_channels); + + for (int i = 0; i < tf->num_single_layers; i++) { + const single_block_t *b = &tf->single_blocks[i]; + if (!b->norm_q_weight || !b->norm_k_weight) goto fail; + CUDA_H2D(s.d_norm_q_all + (size_t)i * head_dim, b->norm_q_weight, head_dim); + CUDA_H2D(s.d_norm_k_all + (size_t)i * head_dim, b->norm_k_weight, head_dim); + } + + if (!s.stream_main) { + if (cudaStreamCreateWithFlags(&s.stream_main, cudaStreamNonBlocking) != cudaSuccess) goto fail; + } + + s.tf = tf; + s.seq = seq; + s.img_offset = img_offset; + s.img_seq = img_seq; + s.host_txt_cos = NULL; + s.host_txt_sin = NULL; + s.host_img_cos = NULL; + s.host_img_sin = NULL; + } + + if (latent_channels > tf->latent_channels) goto fail; + + CUDA_H2D(s.d_hidden, hidden, (size_t)seq * hidden_size); + CUDA_H2D(s.d_shift, shift, hidden_size); + CUDA_H2D(s.d_scale, scale, hidden_size); + CUDA_H2D(s.d_gate, gate, hidden_size); + if (s.host_txt_cos != txt_rope_cos || s.host_txt_sin != txt_rope_sin) { + CUDA_H2D(s.d_txt_cos, txt_rope_cos, (size_t)img_offset * head_dim); + CUDA_H2D(s.d_txt_sin, txt_rope_sin, (size_t)img_offset * head_dim); + s.host_txt_cos = txt_rope_cos; + s.host_txt_sin = txt_rope_sin; + } + if (s.host_img_cos != img_rope_cos || s.host_img_sin != img_rope_sin) { + CUDA_H2D(s.d_img_cos, img_rope_cos, (size_t)img_seq * head_dim); + CUDA_H2D(s.d_img_sin, img_rope_sin, (size_t)img_seq * head_dim); + s.host_img_cos = img_rope_cos; + s.host_img_sin = img_rope_sin; + } + if (has_final) { + if (!final_shift || !final_scale || !final_proj_weight) goto fail; + CUDA_H2D(s.d_final_shift, final_shift, hidden_size); + CUDA_H2D(s.d_final_scale, final_scale, hidden_size); + } + + CUDA_SET_STREAM(s.stream_main); + + if (use_graph) { + int graph_match = s.graph_valid && + s.graph_has_final == has_final && + s.graph_seq == seq && + s.graph_img_offset == img_offset && + s.graph_latent_channels == latent_channels; + if (!graph_match) { + if (s.graph_exec) { cudaGraphExecDestroy(s.graph_exec); s.graph_exec = NULL; } + if (s.graph) { cudaGraphDestroy(s.graph); s.graph = NULL; } + s.graph_valid = 0; + + if (!flux_cuda_attention_batched_shd_device(s.d_attn_out, s.d_q, s.d_k, s.d_v, + heads, seq, seq, head_dim, + attn_scale, 0, NULL)) { + use_graph = 0; + } else if (cudaStreamSynchronize(s.stream_main) != cudaSuccess) { + use_graph = 0; + } else if (cudaStreamBeginCapture(s.stream_main, cudaStreamCaptureModeRelaxed) == cudaSuccess) { + int cap_ok = 1; + for (int i = 0; i < tf->num_single_layers; i++) { + const single_block_t *b = &tf->single_blocks[i]; + const float *d_norm_q = s.d_norm_q_all + (size_t)i * head_dim; + const float *d_norm_k = s.d_norm_k_all + (size_t)i * head_dim; + if (!b->qkv_mlp_weight || !b->proj_mlp_weight) { cap_ok = 0; break; } + if (!flux_cuda_adaln_norm_device(s.d_norm, s.d_hidden, s.d_shift, s.d_scale, seq, hidden_size, eps)) { cap_ok = 0; break; } + if (!flux_cuda_linear_nobias_device(s.d_fused, s.d_norm, b->qkv_mlp_weight, + seq, hidden_size, fused_dim)) { cap_ok = 0; break; } + if (!flux_cuda_split_qkv_mlp_device(s.d_fused, s.d_q, s.d_k, s.d_v, s.d_gate_mlp, s.d_up, + seq, hidden_size, mlp_hidden)) { cap_ok = 0; break; } + if (!flux_cuda_qk_rms_norm_device(s.d_q, s.d_k, d_norm_q, d_norm_k, + seq, heads, head_dim, eps)) { cap_ok = 0; break; } + if (!flux_cuda_rope_unified_device(s.d_q, s.d_k, s.d_txt_cos, s.d_txt_sin, s.d_img_cos, s.d_img_sin, + seq, img_offset, heads, head_dim)) { cap_ok = 0; break; } + if (!flux_cuda_attention_batched_shd_device(s.d_attn_out, s.d_q, s.d_k, s.d_v, + heads, seq, seq, head_dim, + attn_scale, 0, NULL)) { cap_ok = 0; break; } + if (!flux_cuda_silu_mul_device(s.d_gate_mlp, s.d_up, seq * mlp_hidden)) { cap_ok = 0; break; } + if (!flux_cuda_concat_attn_mlp_device(s.d_attn_out, s.d_gate_mlp, s.d_concat, + seq, hidden_size, mlp_hidden)) { cap_ok = 0; break; } + if (!flux_cuda_linear_nobias_device(s.d_proj, s.d_concat, b->proj_mlp_weight, + seq, hidden_size + mlp_hidden, hidden_size)) { cap_ok = 0; break; } + if (!flux_cuda_gated_add_device(s.d_hidden, s.d_gate, s.d_proj, seq, hidden_size)) { cap_ok = 0; break; } + } + if (cap_ok && has_final) { + const float *d_img_hidden = s.d_hidden + (size_t)img_offset * hidden_size; + if (!flux_cuda_adaln_norm_device(s.d_final_norm, d_img_hidden, s.d_final_shift, s.d_final_scale, + img_seq, hidden_size, eps)) cap_ok = 0; + if (cap_ok && !flux_cuda_linear_nobias_device(s.d_final_out, s.d_final_norm, final_proj_weight, + img_seq, hidden_size, latent_channels)) cap_ok = 0; + } + + cudaGraph_t cap_graph = NULL; + if (cudaStreamEndCapture(s.stream_main, &cap_graph) != cudaSuccess || !cap_ok || !cap_graph) { + if (cap_graph) cudaGraphDestroy(cap_graph); + use_graph = 0; + } else if (cudaGraphInstantiate(&s.graph_exec, cap_graph, 0) != cudaSuccess) { + cudaGraphDestroy(cap_graph); + use_graph = 0; + } else { + s.graph = cap_graph; + s.graph_valid = 1; + s.graph_has_final = has_final; + s.graph_seq = seq; + s.graph_img_offset = img_offset; + s.graph_latent_channels = latent_channels; + } + } else { + use_graph = 0; + } + } + + if (use_graph && s.graph_valid) { + if (cudaGraphLaunch(s.graph_exec, s.stream_main) != cudaSuccess) { + use_graph = 0; + s.graph_valid = 0; + } + } + } + + if (!use_graph) { + RUN_SINGLE_LOOP(1); + if (has_final) { + const float *d_img_hidden = s.d_hidden + (size_t)img_offset * hidden_size; + if (!flux_cuda_adaln_norm_device(s.d_final_norm, d_img_hidden, s.d_final_shift, s.d_final_scale, + img_seq, hidden_size, eps)) goto fail; + if (!flux_cuda_linear_nobias_device(s.d_final_out, s.d_final_norm, final_proj_weight, + img_seq, hidden_size, latent_channels)) goto fail; + } + } + + if (cudaStreamSynchronize(s.stream_main) != cudaSuccess) goto fail; + if (use_graph && flux_substep_callback) { + for (int i = 0; i < tf->num_single_layers; i++) { + flux_substep_callback(FLUX_SUBSTEP_SINGLE_BLOCK, i, tf->num_single_layers); + } + } + + if (has_final) { + CUDA_D2H(final_out_nlc, s.d_final_out, (size_t)img_seq * latent_channels); + } else { + CUDA_D2H(hidden, s.d_hidden, (size_t)seq * hidden_size); + } + ok = 1; + +fail: + flux_cuda_linear_set_stream(NULL); + flux_cuda_ops_set_stream(NULL); + return ok; + +#undef CUDA_FREE_PTR +#undef CUDA_ALLOC_PTR +#undef CUDA_H2D +#undef CUDA_D2H +#undef CUDA_SET_STREAM +#undef RUN_SINGLE_LOOP +} +#endif /* USE_CUDA */ + +static void compute_single_mod_params(flux_transformer_t *tf, + const float *t_emb, + const float *adaln_weight) { + int hidden = tf->hidden_size; + int mod_size = hidden * 3; + + for (int i = 0; i < hidden; i++) { + float x = t_emb[i]; + tf->t_emb_silu[i] = x / (1.0f + expf(-x)); + } + flux_linear_nobias(tf->single_mod_params, tf->t_emb_silu, + adaln_weight, 1, hidden, mod_size); +} + /* ======================================================================== * Full Transformer Forward Pass * ======================================================================== */ @@ -3581,6 +4087,7 @@ float *flux_transformer_forward(flux_transformer_t *tf, /* Double-stream blocks */ double double_start = tf_get_time_ms(); + int cuda_resident_double_ok = 0; /* Pre-compute AdaLN modulation ONCE for all 5 double blocks. * t_emb and adaln weights are the same for all blocks within a step. */ @@ -3594,38 +4101,54 @@ float *flux_transformer_forward(flux_transformer_t *tf, flux_linear_nobias(tf->double_mod_txt, tf->t_emb_silu, tf->adaln_double_txt_weight, 1, hidden, double_mod_size); - for (int i = 0; i < tf->num_double_layers; i++) { - /* In mmap mode, load block weights on-demand and free after use */ - if (tf->use_mmap && tf->double_blocks[i].img_q_weight == NULL - && tf->double_blocks[i].img_q_weight_bf16 == NULL) { - load_double_block_weights(&tf->double_blocks[i], tf->sf_files, tf->num_sf_files, i, - tf->hidden_size, tf->mlp_hidden, tf->use_bf16); - } - double_block_forward(img_hidden, txt_hidden, - &tf->double_blocks[i], - tf->double_mod_img, tf->double_mod_txt, - img_rope_cos, img_rope_sin, - txt_rope_cos, txt_rope_sin, - img_seq, txt_seq, tf); - if (tf->use_mmap) free_double_block_weights(&tf->double_blocks[i]); - if (flux_substep_callback) - flux_substep_callback(FLUX_SUBSTEP_DOUBLE_BLOCK, i, tf->num_double_layers); +#ifdef USE_CUDA + if (!getenv("FLUX_CUDA_NO_RESIDENT_DOUBLE")) { + cuda_resident_double_ok = double_blocks_forward_cuda_resident(img_hidden, txt_hidden, + tf, + tf->double_mod_img, tf->double_mod_txt, + img_rope_cos, img_rope_sin, + txt_rope_cos, txt_rope_sin, + img_seq, txt_seq); + } +#endif + + if (!cuda_resident_double_ok) { + for (int i = 0; i < tf->num_double_layers; i++) { + /* In mmap mode, load block weights on-demand */ + if (tf->use_mmap) { + load_double_block_weights(&tf->double_blocks[i], tf->sf, i, + tf->hidden_size, tf->mlp_hidden, tf->use_bf16); + } + double_block_forward(img_hidden, txt_hidden, + &tf->double_blocks[i], + tf->double_mod_img, tf->double_mod_txt, + img_rope_cos, img_rope_sin, + txt_rope_cos, txt_rope_sin, + img_seq, txt_seq, tf); + /* In mmap mode, free block weights after use */ + if (tf->use_mmap) { + free_double_block_weights(&tf->double_blocks[i]); + /* With direct mmap pointers for bf16, no need to clear caches. */ + } + if (flux_substep_callback) + flux_substep_callback(FLUX_SUBSTEP_DOUBLE_BLOCK, i, tf->num_double_layers); #ifdef DEBUG_TRANSFORMER - if (i == 0) { - fprintf(stderr, "\n[DEBUG] After double block 0:\n"); - fprintf(stderr, "[DEBUG] img_hidden[0,0,:10]: "); - for (int d = 0; d < 10; d++) fprintf(stderr, "%.6f ", img_hidden[d]); - fprintf(stderr, "\n"); - float sum = 0, sum_sq = 0; - for (int d = 0; d < img_seq * hidden; d++) { - sum += img_hidden[d]; - sum_sq += img_hidden[d] * img_hidden[d]; + if (i == 0) { + fprintf(stderr, "\n[DEBUG] After double block 0:\n"); + fprintf(stderr, "[DEBUG] img_hidden[0,0,:10]: "); + for (int d = 0; d < 10; d++) fprintf(stderr, "%.6f ", img_hidden[d]); + fprintf(stderr, "\n"); + float sum = 0, sum_sq = 0; + for (int d = 0; d < img_seq * hidden; d++) { + sum += img_hidden[d]; + sum_sq += img_hidden[d] * img_hidden[d]; + } + float mean = sum / (img_seq * hidden); + float std = sqrtf(sum_sq / (img_seq * hidden) - mean * mean); + fprintf(stderr, "[DEBUG] img_hidden mean=%.6f, std=%.6f\n", mean, std); } - float mean = sum / (img_seq * hidden); - float std = sqrtf(sum_sq / (img_seq * hidden) - mean * mean); - fprintf(stderr, "[DEBUG] img_hidden mean=%.6f, std=%.6f\n", mean, std); - } #endif + } } double double_time = tf_get_time_ms() - double_start; @@ -3640,6 +4163,42 @@ float *flux_transformer_forward(flux_transformer_t *tf, /* Single-stream blocks */ double single_start = tf_get_time_ms(); + int cuda_resident_ok = 0; + float *cuda_final_output_nlc = NULL; + +#ifdef USE_CUDA + if (!getenv("FLUX_CUDA_NO_RESIDENT_SINGLE")) { + compute_single_mod_params(tf, t_emb, tf->adaln_single_weight); + + if (!getenv("FLUX_CUDA_NO_RESIDENT_FINAL")) { + /* Precompute final modulation for optional CUDA-resident final projection. */ + for (int j = 0; j < hidden; j++) { + float x = t_emb[j]; + tf->t_emb_silu[j] = x / (1.0f + expf(-x)); + } + flux_linear_nobias(tf->double_mod_img, tf->t_emb_silu, + tf->final_norm_weight, 1, hidden, hidden * 2); + cuda_final_output_nlc = (float *)malloc((size_t)img_seq * tf->latent_channels * sizeof(float)); + } + + cuda_resident_ok = single_blocks_forward_cuda_resident(concat_hidden, tf, + tf->single_mod_params, + tf->single_mod_params + hidden, + tf->single_mod_params + hidden * 2, + img_rope_cos, img_rope_sin, + txt_rope_cos, txt_rope_sin, + total_seq, txt_seq, + cuda_final_output_nlc ? (tf->double_mod_img + hidden) : NULL, + cuda_final_output_nlc ? tf->double_mod_img : NULL, + cuda_final_output_nlc ? tf->final_proj_weight : NULL, + cuda_final_output_nlc ? tf->latent_channels : 0, + cuda_final_output_nlc); + if (!cuda_resident_ok && cuda_final_output_nlc) { + free(cuda_final_output_nlc); + cuda_final_output_nlc = NULL; + } + } +#endif #ifdef USE_METAL /* Try BF16 native path first */ @@ -3652,7 +4211,7 @@ float *flux_transformer_forward(flux_transformer_t *tf, * MPS SGEMM with f16 weights. To achieve higher bf16 performance, we would need * highly optimized custom Metal matmul kernels. * The f32 path with f16 weights and pre-warmed caches is currently faster. */ - if (0 && flux_metal_available() && flux_bf16_pipeline_available() && !tf->use_mmap && tf->use_bf16) { + if (!cuda_resident_ok && 0 && flux_metal_available() && flux_bf16_pipeline_available() && !tf->use_mmap && tf->use_bf16) { /* Create f32 GPU tensor first, then convert to bf16 */ flux_gpu_tensor_t hidden_f32 = flux_gpu_tensor_create(concat_hidden, total_seq * hidden); if (hidden_f32) { @@ -3748,7 +4307,7 @@ float *flux_transformer_forward(flux_transformer_t *tf, } /* Fall back to f32 GPU-chained path if bf16 path not used or failed */ - if (!bf16_path_ok && flux_metal_available() && flux_metal_shaders_available() && !tf->use_mmap) { + if (!cuda_resident_ok && !bf16_path_ok && flux_metal_available() && flux_metal_shaders_available() && !tf->use_mmap) { /* Create persistent GPU tensor for hidden state */ concat_hidden_gpu = flux_gpu_tensor_create(concat_hidden, total_seq * hidden); if (concat_hidden_gpu) { @@ -3812,15 +4371,23 @@ float *flux_transformer_forward(flux_transformer_t *tf, concat_hidden_gpu = NULL; } } +#endif - /* Fall back to per-block GPU/CPU path if both bf16 and f32 chained paths failed */ - if (!bf16_path_ok && !gpu_chained_ok) { + /* Fall back to per-block GPU/CPU path if GPU-chained paths were not used */ +#ifdef USE_METAL + if (!cuda_resident_ok && !bf16_path_ok && !gpu_chained_ok) { +#else + if (!cuda_resident_ok) { #endif + compute_single_mod_params(tf, t_emb, tf->adaln_single_weight); + float *single_shift = tf->single_mod_params; + float *single_scale = tf->single_mod_params + hidden; + float *single_gate = tf->single_mod_params + hidden * 2; + for (int i = 0; i < tf->num_single_layers; i++) { - /* In mmap mode, load block weights on-demand and free after use */ - if (tf->use_mmap && tf->single_blocks[i].qkv_mlp_weight == NULL - && tf->single_blocks[i].qkv_mlp_weight_bf16 == NULL) { - load_single_block_weights(&tf->single_blocks[i], tf->sf_files, tf->num_sf_files, i, + /* In mmap mode, load block weights on-demand */ + if (tf->use_mmap) { + load_single_block_weights(&tf->single_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } #ifdef USE_METAL @@ -3833,13 +4400,17 @@ float *flux_transformer_forward(flux_transformer_t *tf, #endif { /* Fall back to CPU path */ - single_block_forward(concat_hidden, &tf->single_blocks[i], - t_emb, tf->adaln_single_weight, - img_rope_cos, img_rope_sin, - txt_rope_cos, txt_rope_sin, - total_seq, txt_seq, tf); /* txt_seq is the offset to image */ + single_block_forward_precomputed(concat_hidden, &tf->single_blocks[i], + single_shift, single_scale, single_gate, + img_rope_cos, img_rope_sin, + txt_rope_cos, txt_rope_sin, + total_seq, txt_seq, tf); /* txt_seq is the offset to image */ + } + /* In mmap mode, free block weights after use */ + if (tf->use_mmap) { + free_single_block_weights(&tf->single_blocks[i]); + /* With direct mmap pointers for bf16, no need to clear caches. */ } - if (tf->use_mmap) free_single_block_weights(&tf->single_blocks[i]); if (flux_substep_callback) flux_substep_callback(FLUX_SUBSTEP_SINGLE_BLOCK, i, tf->num_single_layers); @@ -3853,59 +4424,81 @@ float *flux_transformer_forward(flux_transformer_t *tf, } #endif } -#ifdef USE_METAL } -#endif double single_time = tf_get_time_ms() - single_start; + double final_start = tf_get_time_ms(); + float *output = NULL; - /* Extract image hidden states (image is after text) */ - memcpy(img_hidden, concat_hidden + txt_seq * hidden, img_seq * hidden * sizeof(float)); - free(concat_hidden); + if (cuda_resident_ok && cuda_final_output_nlc) { + /* Final projection already ran on CUDA in resident path. + * Convert NLC -> NCHW on CPU (small tensor). */ + output = (float *)malloc((size_t)img_seq * tf->latent_channels * sizeof(float)); + if (output) { + for (int pos = 0; pos < img_seq; pos++) { + for (int c = 0; c < channels; c++) { + output[c * img_seq + pos] = cuda_final_output_nlc[pos * channels + c]; + } + } + } + free(cuda_final_output_nlc); + free(concat_hidden); + } else { + if (cuda_final_output_nlc) { + free(cuda_final_output_nlc); + cuda_final_output_nlc = NULL; + } + + /* Extract image hidden states (image is after text). */ + memcpy(img_hidden, concat_hidden + txt_seq * hidden, img_seq * hidden * sizeof(float)); + free(concat_hidden); #ifdef DEBUG_FINAL_LAYER - fprintf(stderr, "[FINAL] Before final layer img_hidden[0,0,:5]: "); - for (int d = 0; d < 5; d++) fprintf(stderr, "%.6f ", img_hidden[d]); - fprintf(stderr, "\n"); + fprintf(stderr, "[FINAL] Before final layer img_hidden[0,0,:5]: "); + for (int d = 0; d < 5; d++) fprintf(stderr, "%.6f ", img_hidden[d]); + fprintf(stderr, "\n"); #endif - /* Final layer: AdaLN modulation -> project to latent channels - * norm_out.linear.weight is [6144, 3072] = [shift, scale] projection - * Apply SiLU to t_emb before modulation projection (FLUX architecture) - */ - double final_start = tf_get_time_ms(); - /* Reuse pre-allocated t_emb_silu buffer */ - for (int i = 0; i < hidden; i++) { - float x = t_emb[i]; - tf->t_emb_silu[i] = x / (1.0f + expf(-x)); - } + /* Final layer: AdaLN modulation -> project to latent channels + * norm_out.linear.weight is [6144, 3072] = [shift, scale] projection + * Apply SiLU to t_emb before modulation projection (FLUX architecture) + */ + for (int i = 0; i < hidden; i++) { + float x = t_emb[i]; + tf->t_emb_silu[i] = x / (1.0f + expf(-x)); + } - /* Reuse double_mod_img buffer for final_mod (needs hidden*2, has hidden*6) */ - float *final_mod = tf->double_mod_img; - flux_linear_nobias(final_mod, tf->t_emb_silu, tf->final_norm_weight, 1, hidden, hidden * 2); + /* Reuse double_mod_img buffer for final_mod (needs hidden*2, has hidden*6) */ + float *final_mod = tf->double_mod_img; + flux_linear_nobias(final_mod, tf->t_emb_silu, tf->final_norm_weight, 1, hidden, hidden * 2); - /* Python: scale, shift = mod.chunk(2, dim=1) - scale is first half, shift is second half */ - float *final_scale = final_mod; - float *final_shift = final_mod + hidden; + /* Python: scale, shift = mod.chunk(2, dim=1) - scale is first half, shift is second half */ + float *final_scale = final_mod; + float *final_shift = final_mod + hidden; - float *final_norm = tf->work1; - apply_adaln(final_norm, img_hidden, final_shift, final_scale, img_seq, hidden, 1e-6f); + float *final_norm = tf->work1; + apply_adaln(final_norm, img_hidden, final_shift, final_scale, img_seq, hidden, 1e-6f); - float *output_nlc = (float *)malloc(img_seq * tf->latent_channels * sizeof(float)); - LINEAR_BF16_OR_F32(output_nlc, final_norm, tf->final_proj_weight, tf->final_proj_weight_bf16, - img_seq, hidden, tf->latent_channels); + float *output_nlc = (float *)malloc((size_t)img_seq * tf->latent_channels * sizeof(float)); + if (output_nlc) { + LINEAR_BF16_OR_F32(output_nlc, final_norm, tf->final_proj_weight, tf->final_proj_weight_bf16, + img_seq, hidden, tf->latent_channels); - /* Transpose output from NLC [seq, channels] to NCHW [channels, h, w] format - * Input: output_nlc[pos * channels + c] - * Output: output[c * img_seq + pos] - */ - float *output = (float *)malloc(img_seq * tf->latent_channels * sizeof(float)); - for (int pos = 0; pos < img_seq; pos++) { - for (int c = 0; c < channels; c++) { - output[c * img_seq + pos] = output_nlc[pos * channels + c]; + /* Transpose output from NLC [seq, channels] to NCHW [channels, h, w] format + * Input: output_nlc[pos * channels + c] + * Output: output[c * img_seq + pos] + */ + output = (float *)malloc((size_t)img_seq * tf->latent_channels * sizeof(float)); + if (output) { + for (int pos = 0; pos < img_seq; pos++) { + for (int c = 0; c < channels; c++) { + output[c * img_seq + pos] = output_nlc[pos * channels + c]; + } + } + } + free(output_nlc); } } - free(output_nlc); free(t_emb); /* RoPE buffers are cached in the transformer and freed in flux_transformer_free(). */ @@ -4064,7 +4657,7 @@ float *flux_transformer_forward_with_refs(flux_transformer_t *tf, /* Double blocks - process combined image with text */ for (int i = 0; i < tf->num_double_layers; i++) { if (tf->use_mmap) { - load_double_block_weights(&tf->double_blocks[i], tf->sf_files, tf->num_sf_files, i, + load_double_block_weights(&tf->double_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } double_block_forward(combined_hidden, txt_hidden, @@ -4088,16 +4681,20 @@ float *flux_transformer_forward_with_refs(flux_transformer_t *tf, free(combined_hidden); /* Single blocks */ + compute_single_mod_params(tf, t_emb, tf->adaln_single_weight); + float *single_shift = tf->single_mod_params; + float *single_scale = tf->single_mod_params + hidden; + float *single_gate = tf->single_mod_params + hidden * 2; for (int i = 0; i < tf->num_single_layers; i++) { if (tf->use_mmap) { - load_single_block_weights(&tf->single_blocks[i], tf->sf_files, tf->num_sf_files, i, + load_single_block_weights(&tf->single_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } - single_block_forward(concat_hidden, &tf->single_blocks[i], - t_emb, tf->adaln_single_weight, - combined_rope_cos, combined_rope_sin, - txt_rope_cos, txt_rope_sin, - total_seq, txt_seq, tf); + single_block_forward_precomputed(concat_hidden, &tf->single_blocks[i], + single_shift, single_scale, single_gate, + combined_rope_cos, combined_rope_sin, + txt_rope_cos, txt_rope_sin, + total_seq, txt_seq, tf); if (tf->use_mmap) { free_single_block_weights(&tf->single_blocks[i]); /* With direct mmap pointers for bf16, no need to clear caches. */ @@ -4308,7 +4905,7 @@ float *flux_transformer_forward_with_multi_refs(flux_transformer_t *tf, /* Double blocks */ for (int i = 0; i < tf->num_double_layers; i++) { if (tf->use_mmap) { - load_double_block_weights(&tf->double_blocks[i], tf->sf_files, tf->num_sf_files, i, + load_double_block_weights(&tf->double_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } double_block_forward(combined_hidden, txt_hidden, @@ -4331,16 +4928,20 @@ float *flux_transformer_forward_with_multi_refs(flux_transformer_t *tf, free(combined_hidden); /* Single blocks */ + compute_single_mod_params(tf, t_emb, tf->adaln_single_weight); + float *single_shift = tf->single_mod_params; + float *single_scale = tf->single_mod_params + hidden; + float *single_gate = tf->single_mod_params + hidden * 2; for (int i = 0; i < tf->num_single_layers; i++) { if (tf->use_mmap) { - load_single_block_weights(&tf->single_blocks[i], tf->sf_files, tf->num_sf_files, i, + load_single_block_weights(&tf->single_blocks[i], tf->sf, i, tf->hidden_size, tf->mlp_hidden, tf->use_bf16); } - single_block_forward(concat_hidden, &tf->single_blocks[i], - t_emb, tf->adaln_single_weight, - combined_rope_cos, combined_rope_sin, - txt_rope_cos, txt_rope_sin, - total_seq, txt_seq, tf); + single_block_forward_precomputed(concat_hidden, &tf->single_blocks[i], + single_shift, single_scale, single_gate, + combined_rope_cos, combined_rope_sin, + txt_rope_cos, txt_rope_sin, + total_seq, txt_seq, tf); if (tf->use_mmap) { free_single_block_weights(&tf->single_blocks[i]); } @@ -4529,8 +5130,9 @@ flux_transformer_t *flux_transformer_load(FILE *f) { tf->t_emb_silu = (float *)malloc(hidden * sizeof(float)); tf->double_mod_img = (float *)malloc(hidden * 6 * sizeof(float)); tf->double_mod_txt = (float *)malloc(hidden * 6 * sizeof(float)); + tf->single_mod_params = (float *)malloc(hidden * 3 * sizeof(float)); - if (!tf->t_emb_silu || !tf->double_mod_img || !tf->double_mod_txt) { + if (!tf->t_emb_silu || !tf->double_mod_img || !tf->double_mod_txt || !tf->single_mod_params) { goto error; } @@ -4544,13 +5146,6 @@ flux_transformer_t *flux_transformer_load(FILE *f) { void flux_transformer_free(flux_transformer_t *tf) { if (!tf) return; - /* In mmap mode, bf16 pointers point into the mmap'd file region and must - * NOT be freed. Clean up any cached block weights first, then only NULL - * the bf16 pointers (don't free them). */ - if (tf->use_mmap) { - flux_transformer_free_mmap_cache(tf); - } - free(tf->img_in_weight); free(tf->txt_in_weight); free(tf->img_in_weight_bf16); @@ -4566,35 +5161,33 @@ void flux_transformer_free(flux_transformer_t *tf) { free(b->img_q_weight); free(b->img_k_weight); free(b->img_v_weight); + free(b->img_q_weight_bf16); + free(b->img_k_weight_bf16); + free(b->img_v_weight_bf16); free(b->img_proj_weight); + free(b->img_proj_weight_bf16); free(b->img_mlp_gate_weight); free(b->img_mlp_up_weight); free(b->img_mlp_down_weight); + free(b->img_mlp_gate_weight_bf16); + free(b->img_mlp_up_weight_bf16); + free(b->img_mlp_down_weight_bf16); free(b->txt_norm_q_weight); free(b->txt_norm_k_weight); free(b->txt_q_weight); free(b->txt_k_weight); free(b->txt_v_weight); + free(b->txt_q_weight_bf16); + free(b->txt_k_weight_bf16); + free(b->txt_v_weight_bf16); free(b->txt_proj_weight); + free(b->txt_proj_weight_bf16); free(b->txt_mlp_gate_weight); free(b->txt_mlp_up_weight); free(b->txt_mlp_down_weight); - if (!tf->use_mmap) { - free(b->img_q_weight_bf16); - free(b->img_k_weight_bf16); - free(b->img_v_weight_bf16); - free(b->img_proj_weight_bf16); - free(b->img_mlp_gate_weight_bf16); - free(b->img_mlp_up_weight_bf16); - free(b->img_mlp_down_weight_bf16); - free(b->txt_q_weight_bf16); - free(b->txt_k_weight_bf16); - free(b->txt_v_weight_bf16); - free(b->txt_proj_weight_bf16); - free(b->txt_mlp_gate_weight_bf16); - free(b->txt_mlp_up_weight_bf16); - free(b->txt_mlp_down_weight_bf16); - } + free(b->txt_mlp_gate_weight_bf16); + free(b->txt_mlp_up_weight_bf16); + free(b->txt_mlp_down_weight_bf16); } free(tf->double_blocks); } @@ -4605,11 +5198,9 @@ void flux_transformer_free(flux_transformer_t *tf) { free(b->norm_q_weight); free(b->norm_k_weight); free(b->qkv_mlp_weight); + free(b->qkv_mlp_weight_bf16); free(b->proj_mlp_weight); - if (!tf->use_mmap) { - free(b->qkv_mlp_weight_bf16); - free(b->proj_mlp_weight_bf16); - } + free(b->proj_mlp_weight_bf16); } free(tf->single_blocks); } @@ -4655,6 +5246,7 @@ void flux_transformer_free(flux_transformer_t *tf) { free(tf->t_emb_silu); free(tf->double_mod_img); free(tf->double_mod_txt); + free(tf->single_mod_params); free(tf->double_img_attn_out); free(tf->double_txt_attn_out); @@ -4668,12 +5260,10 @@ void flux_transformer_free(flux_transformer_t *tf) { free(tf->cached_combined_rope_cos); free(tf->cached_combined_rope_sin); - /* Close safetensors files if in mmap mode */ - if (tf->use_mmap) { - for (int i = 0; i < tf->num_sf_files; i++) { - if (tf->sf_files[i]) safetensors_close(tf->sf_files[i]); - } - tf->num_sf_files = 0; + /* Close safetensors file if in mmap mode */ + if (tf->use_mmap && tf->sf) { + safetensors_close(tf->sf); + tf->sf = NULL; } free(tf); @@ -4683,25 +5273,25 @@ void flux_transformer_free(flux_transformer_t *tf) { * Safetensors Loading * ======================================================================== */ -static float *get_sf_tensor_tf(safetensors_file_t **files, int num_files, const char *name) { - for (int f = 0; f < num_files; f++) { - const safetensor_t *t = safetensors_find(files[f], name); - if (t) return safetensors_get_f32(files[f], t); +static float *get_sf_tensor_tf(safetensors_file_t *sf, const char *name) { + const safetensor_t *t = safetensors_find(sf, name); + if (!t) { + fprintf(stderr, "Error: required tensor %s not found\n", name); + return NULL; } - fprintf(stderr, "Error: required tensor %s not found\n", name); - return NULL; + return safetensors_get_f32(sf, t); } /* Get tensor as bf16 (for GPU acceleration) */ -static uint16_t *get_sf_tensor_bf16(safetensors_file_t **files, int num_files, const char *name) { - for (int f = 0; f < num_files; f++) { - const safetensor_t *t = safetensors_find(files[f], name); - if (t) { - if (!safetensor_is_bf16(t)) return NULL; - return safetensors_get_bf16(files[f], t); - } +static uint16_t *get_sf_tensor_bf16(safetensors_file_t *sf, const char *name) { + const safetensor_t *t = safetensors_find(sf, name); + if (!t) { + return NULL; /* Not an error - bf16 is optional */ + } + if (!safetensor_is_bf16(t)) { + return NULL; /* Not bf16, will use f32 version */ } - return NULL; /* Not found - bf16 is optional */ + return safetensors_get_bf16(sf, t); } #ifdef USE_METAL @@ -4776,44 +5366,43 @@ static void warmup_bf16_weights(flux_transformer_t *tf) { } #endif /* USE_METAL */ -flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { +flux_transformer_t *flux_transformer_load_safetensors(safetensors_file_t *sf) { flux_transformer_t *tf = calloc(1, sizeof(flux_transformer_t)); if (!tf) return NULL; char name[256]; - /* Parse config from transformer/config.json, fall back to 4B defaults */ - if (parse_transformer_config(model_dir, tf) != 0) { - tf->hidden_size = 3072; - tf->num_heads = 24; - tf->head_dim = 128; - tf->mlp_hidden = 9216; - tf->num_double_layers = 5; - tf->num_single_layers = 20; - tf->text_dim = 7680; - tf->latent_channels = 128; - tf->rope_theta = 2000.0f; - tf->rope_dim = 128; - tf->axis_dim = 32; - } + /* Set config based on FLUX.2-klein-4B */ + tf->hidden_size = 3072; + tf->num_heads = 24; + tf->head_dim = 128; + tf->mlp_hidden = 9216; + tf->num_double_layers = 5; + tf->num_single_layers = 20; + tf->text_dim = 7680; + tf->latent_channels = 128; + /* Max sequence length must accommodate image + text tokens combined. + * At 1024x1024: img_seq = (1024/8)^2 = 16384, txt_seq = 512, total = 16896 + * At 1792x1792: img_seq = (1792/8)^2 = 50176, txt_seq = 512, total = 50688 + * We set 52000 to support up to 1792x1792 with margin. + */ tf->max_seq_len = 52000; + tf->rope_dim = 128; + tf->rope_theta = 2000.0f; + tf->axis_dim = 32; /* RoPE axis dimension (head_dim = 128 = 4 * axis_dim) */ - /* Open safetensors shards */ - safetensors_file_t *files[MAX_TF_SHARDS]; - int num_files = open_transformer_shards(model_dir, files, MAX_TF_SHARDS); - if (num_files == 0) { - fprintf(stderr, "flux_transformer_load: failed to open safetensors files\n"); - free(tf); - return NULL; - } - - /* Enable bf16 mode if Metal GPU is available */ + /* Enable bf16 mode on GPU backends when available. */ #ifdef USE_METAL tf->use_bf16 = flux_metal_available(); if (tf->use_bf16) { if (flux_verbose) fprintf(stderr, "Using bf16 weights for GPU acceleration\n"); } +#elif defined(USE_CUDA) + tf->use_bf16 = (!getenv("FLUX_CUDA_NO_BF16") && flux_cuda_bf16_linear_available()) ? 1 : 0; + if (tf->use_bf16 && flux_verbose) { + fprintf(stderr, "Using bf16 weights for CUDA tensor-core projections\n"); + } #else tf->use_bf16 = 0; #endif @@ -4823,26 +5412,29 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { /* Input projections */ if (tf->use_bf16) { - tf->img_in_weight_bf16 = get_sf_tensor_bf16(files, num_files, "x_embedder.weight"); - tf->txt_in_weight_bf16 = get_sf_tensor_bf16(files, num_files, "context_embedder.weight"); + tf->img_in_weight_bf16 = get_sf_tensor_bf16(sf, "x_embedder.weight"); + tf->txt_in_weight_bf16 = get_sf_tensor_bf16(sf, "context_embedder.weight"); } else { - tf->img_in_weight = get_sf_tensor_tf(files, num_files, "x_embedder.weight"); - tf->txt_in_weight = get_sf_tensor_tf(files, num_files, "context_embedder.weight"); + tf->img_in_weight = get_sf_tensor_tf(sf, "x_embedder.weight"); + tf->txt_in_weight = get_sf_tensor_tf(sf, "context_embedder.weight"); } - /* Time embedding */ + /* Time embedding + * FLUX.2-klein uses 256-dim sinusoidal embedding (128 frequencies) + * linear_1: [3072, 256], linear_2: [3072, 3072] + */ tf->time_embed.sincos_dim = 256; - tf->time_embed.fc1_weight = get_sf_tensor_tf(files, num_files, + tf->time_embed.fc1_weight = get_sf_tensor_tf(sf, "time_guidance_embed.timestep_embedder.linear_1.weight"); - tf->time_embed.fc2_weight = get_sf_tensor_tf(files, num_files, + tf->time_embed.fc2_weight = get_sf_tensor_tf(sf, "time_guidance_embed.timestep_embedder.linear_2.weight"); /* Modulation weights - these are always needed in f32 for CPU modulation computation */ - tf->adaln_double_img_weight = get_sf_tensor_tf(files, num_files, + tf->adaln_double_img_weight = get_sf_tensor_tf(sf, "double_stream_modulation_img.linear.weight"); - tf->adaln_double_txt_weight = get_sf_tensor_tf(files, num_files, + tf->adaln_double_txt_weight = get_sf_tensor_tf(sf, "double_stream_modulation_txt.linear.weight"); - tf->adaln_single_weight = get_sf_tensor_tf(files, num_files, + tf->adaln_single_weight = get_sf_tensor_tf(sf, "single_stream_modulation.linear.weight"); /* Double blocks */ @@ -4852,29 +5444,29 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { /* Image attention - QK norm weights (always f32) */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_q.weight", i); - b->img_norm_q_weight = get_sf_tensor_tf(files, num_files, name); + b->img_norm_q_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_k.weight", i); - b->img_norm_k_weight = get_sf_tensor_tf(files, num_files, name); + b->img_norm_k_weight = get_sf_tensor_tf(sf, name); /* Image Q, K, V projections (separate) */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_q.weight", i); - if (tf->use_bf16) b->img_q_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->img_q_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->img_q_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->img_q_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_k.weight", i); - if (tf->use_bf16) b->img_k_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->img_k_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->img_k_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->img_k_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_v.weight", i); - if (tf->use_bf16) b->img_v_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->img_v_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->img_v_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->img_v_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_out.0.weight", i); - if (tf->use_bf16) b->img_proj_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->img_proj_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->img_proj_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->img_proj_weight = get_sf_tensor_tf(sf, name); - /* Image FFN - linear_in contains gate and up fused */ + /* Image FFN - linear_in contains gate and up fused (18432 = 2*9216) */ snprintf(name, sizeof(name), "transformer_blocks.%d.ff.linear_in.weight", i); if (tf->use_bf16) { - uint16_t *ff_in_bf16 = get_sf_tensor_bf16(files, num_files, name); + uint16_t *ff_in_bf16 = get_sf_tensor_bf16(sf, name); if (ff_in_bf16) { b->img_mlp_gate_weight_bf16 = malloc(mlp * h * sizeof(uint16_t)); b->img_mlp_up_weight_bf16 = malloc(mlp * h * sizeof(uint16_t)); @@ -4883,7 +5475,7 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { free(ff_in_bf16); } } else { - float *ff_in = get_sf_tensor_tf(files, num_files, name); + float *ff_in = get_sf_tensor_tf(sf, name); if (ff_in) { b->img_mlp_gate_weight = malloc(mlp * h * sizeof(float)); b->img_mlp_up_weight = malloc(mlp * h * sizeof(float)); @@ -4894,33 +5486,33 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { } snprintf(name, sizeof(name), "transformer_blocks.%d.ff.linear_out.weight", i); - if (tf->use_bf16) b->img_mlp_down_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->img_mlp_down_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->img_mlp_down_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->img_mlp_down_weight = get_sf_tensor_tf(sf, name); /* Text stream - QK norm weights (always f32) */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_added_q.weight", i); - b->txt_norm_q_weight = get_sf_tensor_tf(files, num_files, name); + b->txt_norm_q_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.norm_added_k.weight", i); - b->txt_norm_k_weight = get_sf_tensor_tf(files, num_files, name); + b->txt_norm_k_weight = get_sf_tensor_tf(sf, name); /* Text Q, K, V projections (separate) */ snprintf(name, sizeof(name), "transformer_blocks.%d.attn.add_q_proj.weight", i); - if (tf->use_bf16) b->txt_q_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->txt_q_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->txt_q_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->txt_q_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.add_k_proj.weight", i); - if (tf->use_bf16) b->txt_k_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->txt_k_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->txt_k_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->txt_k_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.add_v_proj.weight", i); - if (tf->use_bf16) b->txt_v_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->txt_v_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->txt_v_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->txt_v_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.attn.to_add_out.weight", i); - if (tf->use_bf16) b->txt_proj_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->txt_proj_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->txt_proj_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->txt_proj_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "transformer_blocks.%d.ff_context.linear_in.weight", i); if (tf->use_bf16) { - uint16_t *txt_ff_in_bf16 = get_sf_tensor_bf16(files, num_files, name); + uint16_t *txt_ff_in_bf16 = get_sf_tensor_bf16(sf, name); if (txt_ff_in_bf16) { b->txt_mlp_gate_weight_bf16 = malloc(mlp * h * sizeof(uint16_t)); b->txt_mlp_up_weight_bf16 = malloc(mlp * h * sizeof(uint16_t)); @@ -4929,7 +5521,7 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { free(txt_ff_in_bf16); } } else { - float *txt_ff_in = get_sf_tensor_tf(files, num_files, name); + float *txt_ff_in = get_sf_tensor_tf(sf, name); if (txt_ff_in) { b->txt_mlp_gate_weight = malloc(mlp * h * sizeof(float)); b->txt_mlp_up_weight = malloc(mlp * h * sizeof(float)); @@ -4940,8 +5532,8 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { } snprintf(name, sizeof(name), "transformer_blocks.%d.ff_context.linear_out.weight", i); - if (tf->use_bf16) b->txt_mlp_down_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->txt_mlp_down_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->txt_mlp_down_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->txt_mlp_down_weight = get_sf_tensor_tf(sf, name); } /* Single blocks */ @@ -4951,31 +5543,28 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { /* QK norm weights (always f32, small) */ snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.norm_q.weight", i); - b->norm_q_weight = get_sf_tensor_tf(files, num_files, name); + b->norm_q_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.norm_k.weight", i); - b->norm_k_weight = get_sf_tensor_tf(files, num_files, name); + b->norm_k_weight = get_sf_tensor_tf(sf, name); /* Major linear weights - load bf16 or f32 based on mode */ snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.to_qkv_mlp_proj.weight", i); - if (tf->use_bf16) b->qkv_mlp_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->qkv_mlp_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->qkv_mlp_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->qkv_mlp_weight = get_sf_tensor_tf(sf, name); snprintf(name, sizeof(name), "single_transformer_blocks.%d.attn.to_out.weight", i); - if (tf->use_bf16) b->proj_mlp_weight_bf16 = get_sf_tensor_bf16(files, num_files, name); - else b->proj_mlp_weight = get_sf_tensor_tf(files, num_files, name); + if (tf->use_bf16) b->proj_mlp_weight_bf16 = get_sf_tensor_bf16(sf, name); + else b->proj_mlp_weight = get_sf_tensor_tf(sf, name); } /* Final layer */ - tf->final_norm_weight = get_sf_tensor_tf(files, num_files, "norm_out.linear.weight"); + tf->final_norm_weight = get_sf_tensor_tf(sf, "norm_out.linear.weight"); if (tf->use_bf16) { - tf->final_proj_weight_bf16 = get_sf_tensor_bf16(files, num_files, "proj_out.weight"); + tf->final_proj_weight_bf16 = get_sf_tensor_bf16(sf, "proj_out.weight"); } else { - tf->final_proj_weight = get_sf_tensor_tf(files, num_files, "proj_out.weight"); + tf->final_proj_weight = get_sf_tensor_tf(sf, "proj_out.weight"); } - /* Close safetensors files (non-mmap: data already copied) */ - for (int i = 0; i < num_files; i++) safetensors_close(files[i]); - /* Precompute RoPE frequencies */ tf->rope_freqs = malloc(tf->max_seq_len * tf->head_dim * sizeof(float)); if (tf->rope_freqs) { @@ -5015,8 +5604,9 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { tf->t_emb_silu = malloc(hidden * sizeof(float)); tf->double_mod_img = malloc(hidden * 6 * sizeof(float)); tf->double_mod_txt = malloc(hidden * 6 * sizeof(float)); + tf->single_mod_params = malloc(hidden * 3 * sizeof(float)); - if (!tf->t_emb_silu || !tf->double_mod_img || !tf->double_mod_txt) { + if (!tf->t_emb_silu || !tf->double_mod_img || !tf->double_mod_txt || !tf->single_mod_params) { flux_transformer_free(tf); return NULL; } @@ -5029,78 +5619,73 @@ flux_transformer_t *flux_transformer_load_safetensors(const char *model_dir) { return tf; } -/* Load transformer in mmap mode - only load small weights, keep files open for block loading */ -flux_transformer_t *flux_transformer_load_safetensors_mmap(const char *model_dir) { +/* Load transformer in mmap mode - only load small weights, keep sf open for block loading */ +flux_transformer_t *flux_transformer_load_safetensors_mmap(safetensors_file_t *sf) { flux_transformer_t *tf = calloc(1, sizeof(flux_transformer_t)); if (!tf) return NULL; - /* Parse config from transformer/config.json, fall back to 4B defaults */ - if (parse_transformer_config(model_dir, tf) != 0) { - tf->hidden_size = 3072; - tf->num_heads = 24; - tf->head_dim = 128; - tf->mlp_hidden = 9216; - tf->num_double_layers = 5; - tf->num_single_layers = 20; - tf->text_dim = 7680; - tf->latent_channels = 128; - tf->rope_theta = 2000.0f; - tf->rope_dim = 128; - tf->axis_dim = 32; - } + /* Set config based on FLUX.2-klein-4B */ + tf->hidden_size = 3072; + tf->num_heads = 24; + tf->head_dim = 128; + tf->mlp_hidden = 9216; + tf->num_double_layers = 5; + tf->num_single_layers = 20; + tf->text_dim = 7680; + tf->latent_channels = 128; tf->max_seq_len = 52000; /* Support up to 1792x1792 */ + tf->rope_dim = 128; + tf->rope_theta = 2000.0f; + tf->axis_dim = 32; /* RoPE axis dimension (head_dim = 128 = 4 * axis_dim) */ - /* Open safetensors shards and keep them open for on-demand loading */ + /* Enable mmap mode - keep sf open, don't load block weights yet */ tf->use_mmap = 1; - tf->num_sf_files = open_transformer_shards(model_dir, tf->sf_files, MAX_TF_SHARDS); - if (tf->num_sf_files == 0) { - fprintf(stderr, "flux_transformer_load_mmap: failed to open safetensors files\n"); - free(tf); - return NULL; - } + tf->sf = sf; - /* Enable bf16 mode if Metal GPU is available */ + /* Enable bf16 mode on GPU backends when available. */ #ifdef USE_METAL tf->use_bf16 = flux_metal_available(); if (tf->use_bf16) { if (flux_verbose) fprintf(stderr, "Using bf16 weights for GPU acceleration (mmap mode)\n"); } +#elif defined(USE_CUDA) + tf->use_bf16 = (!getenv("FLUX_CUDA_NO_BF16") && flux_cuda_bf16_linear_available()) ? 1 : 0; + if (tf->use_bf16 && flux_verbose) { + fprintf(stderr, "Using bf16 weights for CUDA tensor-core projections (mmap mode)\n"); + } #else tf->use_bf16 = 0; #endif - safetensors_file_t **files = tf->sf_files; - int num_files = tf->num_sf_files; - /* Input projections - always load (small) */ - tf->img_in_weight = get_sf_tensor_tf(files, num_files, "x_embedder.weight"); - tf->txt_in_weight = get_sf_tensor_tf(files, num_files, "context_embedder.weight"); + tf->img_in_weight = get_sf_tensor_tf(sf, "x_embedder.weight"); + tf->txt_in_weight = get_sf_tensor_tf(sf, "context_embedder.weight"); if (tf->use_bf16) { - tf->img_in_weight_bf16 = get_sf_tensor_bf16(files, num_files, "x_embedder.weight"); - tf->txt_in_weight_bf16 = get_sf_tensor_bf16(files, num_files, "context_embedder.weight"); + tf->img_in_weight_bf16 = get_sf_tensor_bf16(sf, "x_embedder.weight"); + tf->txt_in_weight_bf16 = get_sf_tensor_bf16(sf, "context_embedder.weight"); } /* Time embedding - always load (small) */ tf->time_embed.sincos_dim = 256; - tf->time_embed.fc1_weight = get_sf_tensor_tf(files, num_files, + tf->time_embed.fc1_weight = get_sf_tensor_tf(sf, "time_guidance_embed.timestep_embedder.linear_1.weight"); - tf->time_embed.fc2_weight = get_sf_tensor_tf(files, num_files, + tf->time_embed.fc2_weight = get_sf_tensor_tf(sf, "time_guidance_embed.timestep_embedder.linear_2.weight"); /* Modulation weights - always load */ - tf->adaln_double_img_weight = get_sf_tensor_tf(files, num_files, + tf->adaln_double_img_weight = get_sf_tensor_tf(sf, "double_stream_modulation_img.linear.weight"); - tf->adaln_double_txt_weight = get_sf_tensor_tf(files, num_files, + tf->adaln_double_txt_weight = get_sf_tensor_tf(sf, "double_stream_modulation_txt.linear.weight"); - tf->adaln_single_weight = get_sf_tensor_tf(files, num_files, + tf->adaln_single_weight = get_sf_tensor_tf(sf, "single_stream_modulation.linear.weight"); if (tf->use_bf16) { - tf->adaln_double_img_weight_bf16 = get_sf_tensor_bf16(files, num_files, + tf->adaln_double_img_weight_bf16 = get_sf_tensor_bf16(sf, "double_stream_modulation_img.linear.weight"); - tf->adaln_double_txt_weight_bf16 = get_sf_tensor_bf16(files, num_files, + tf->adaln_double_txt_weight_bf16 = get_sf_tensor_bf16(sf, "double_stream_modulation_txt.linear.weight"); - tf->adaln_single_weight_bf16 = get_sf_tensor_bf16(files, num_files, + tf->adaln_single_weight_bf16 = get_sf_tensor_bf16(sf, "single_stream_modulation.linear.weight"); } @@ -5109,10 +5694,10 @@ flux_transformer_t *flux_transformer_load_safetensors_mmap(const char *model_dir tf->single_blocks = calloc(tf->num_single_layers, sizeof(single_block_t)); /* Final layer - always load (small) */ - tf->final_norm_weight = get_sf_tensor_tf(files, num_files, "norm_out.linear.weight"); - tf->final_proj_weight = get_sf_tensor_tf(files, num_files, "proj_out.weight"); + tf->final_norm_weight = get_sf_tensor_tf(sf, "norm_out.linear.weight"); + tf->final_proj_weight = get_sf_tensor_tf(sf, "proj_out.weight"); if (tf->use_bf16) { - tf->final_proj_weight_bf16 = get_sf_tensor_bf16(files, num_files, "proj_out.weight"); + tf->final_proj_weight_bf16 = get_sf_tensor_bf16(sf, "proj_out.weight"); } /* Precompute RoPE frequencies */ @@ -5153,8 +5738,9 @@ flux_transformer_t *flux_transformer_load_safetensors_mmap(const char *model_dir tf->t_emb_silu = malloc(hidden * sizeof(float)); tf->double_mod_img = malloc(hidden * 6 * sizeof(float)); tf->double_mod_txt = malloc(hidden * 6 * sizeof(float)); + tf->single_mod_params = malloc(hidden * 3 * sizeof(float)); - if (!tf->t_emb_silu || !tf->double_mod_img || !tf->double_mod_txt) { + if (!tf->t_emb_silu || !tf->double_mod_img || !tf->double_mod_txt || !tf->single_mod_params) { flux_transformer_free(tf); return NULL; } @@ -5168,3 +5754,9 @@ flux_transformer_t *flux_transformer_load_safetensors_mmap(const char *model_dir return tf; } + +/* Compatibility hook used by sampler in branches that cache mmap weights + * across denoising steps. This implementation has no persistent mmap cache. */ +void flux_transformer_free_mmap_cache(flux_transformer_t *tf) { + (void)tf; +} diff --git a/flux_vae.c b/flux_vae.c index 1adc995..a872ecb 100644 --- a/flux_vae.c +++ b/flux_vae.c @@ -14,6 +14,10 @@ #include "flux.h" #include "flux_kernels.h" #include "flux_safetensors.h" +#ifdef USE_CUDA +#include "flux_cuda.h" +#include +#endif #ifdef USE_METAL #include "flux_metal.h" #endif @@ -115,6 +119,15 @@ typedef struct flux_vae { int max_h, max_w; float *work1, *work2, *work3; size_t work_size; + + /* Reusable attention scratch (avoid per-call malloc/free in mid-attention). */ + float *attn_q_t; + float *attn_k_t; + float *attn_v_t; + float *attn_o_t; + float *attn_scores; + size_t attn_qkv_capacity; /* floats per q/k/v/o buffer */ + size_t attn_scores_capacity; /* floats in scores buffer */ } flux_vae_t; /* Forward declarations */ @@ -175,6 +188,39 @@ static void swish_inplace(float *x, int n) { flux_silu(x, n); } +static int vae_ensure_attn_scratch(flux_vae_t *vae, int spatial, int ch) { + size_t qkv_need = (size_t)spatial * (size_t)ch; + size_t scores_need = (size_t)spatial * (size_t)spatial; + + if (qkv_need == 0 || scores_need == 0) return 0; + + if (vae->attn_qkv_capacity < qkv_need) { + float *tmp = NULL; + tmp = (float *)realloc(vae->attn_q_t, qkv_need * sizeof(float)); + if (!tmp) return 0; + vae->attn_q_t = tmp; + tmp = (float *)realloc(vae->attn_k_t, qkv_need * sizeof(float)); + if (!tmp) return 0; + vae->attn_k_t = tmp; + tmp = (float *)realloc(vae->attn_v_t, qkv_need * sizeof(float)); + if (!tmp) return 0; + vae->attn_v_t = tmp; + tmp = (float *)realloc(vae->attn_o_t, qkv_need * sizeof(float)); + if (!tmp) return 0; + vae->attn_o_t = tmp; + vae->attn_qkv_capacity = qkv_need; + } + + if (vae->attn_scores_capacity < scores_need) { + float *new_scores = (float *)realloc(vae->attn_scores, scores_need * sizeof(float)); + if (!new_scores) return 0; + vae->attn_scores = new_scores; + vae->attn_scores_capacity = scores_need; + } + + return 1; +} + /* Apply residual block */ static void resblock_forward(float *out, const float *x, const vae_resblock_t *block, @@ -220,7 +266,7 @@ static void resblock_forward(float *out, const float *x, /* Apply self-attention block */ /* Returns 0 on success, -1 on OOM */ -static int attnblock_forward(float *out, const float *x, +static int attnblock_forward(flux_vae_t *vae, float *out, const float *x, const vae_attnblock_t *block, float *work, int batch, int H, int W, int num_groups, float eps) { @@ -249,23 +295,16 @@ static int attnblock_forward(float *out, const float *x, float *attn_out = v + batch * ch * spatial; - /* Allocate attention work buffers once outside the batch loop */ - float *q_t = (float *)malloc(spatial * ch * sizeof(float)); - float *k_t = (float *)malloc(spatial * ch * sizeof(float)); - float *v_t = (float *)malloc(spatial * ch * sizeof(float)); - float *o_t = (float *)malloc(spatial * ch * sizeof(float)); - float *scores = (float *)malloc((size_t)spatial * spatial * sizeof(float)); - - /* Check for allocation failures */ - if (!q_t || !k_t || !v_t || !o_t || !scores) { - free(q_t); - free(k_t); - free(v_t); - free(o_t); - free(scores); + if (!vae_ensure_attn_scratch(vae, spatial, ch)) { return -1; /* OOM */ } + float *q_t = vae->attn_q_t; + float *k_t = vae->attn_k_t; + float *v_t = vae->attn_v_t; + float *o_t = vae->attn_o_t; + float *scores = vae->attn_scores; + for (int b = 0; b < batch; b++) { float *qb = q + b * ch * spatial; float *kb = k + b * ch * spatial; @@ -275,20 +314,35 @@ static int attnblock_forward(float *out, const float *x, /* Transpose [C, HW] -> [HW, C] */ for (int c = 0; c < ch; c++) { for (int i = 0; i < spatial; i++) { - q_t[i * ch + c] = qb[c * spatial + i] * scale; + q_t[i * ch + c] = qb[c * spatial + i]; k_t[i * ch + c] = kb[c * spatial + i]; v_t[i * ch + c] = vb[c * spatial + i]; } } - /* Q @ K^T using BLAS: [HW, C] @ [C, HW] -> [HW, HW] */ - flux_matmul_t(scores, q_t, k_t, spatial, ch, spatial); + int used_cuda_attention = 0; +#ifdef USE_CUDA + if (!getenv("FLUX_CUDA_NO_VAE_ATTN")) { + used_cuda_attention = flux_cuda_attention_single(o_t, q_t, k_t, v_t, + spatial, spatial, ch, + scale, 0, NULL, 0); + } +#endif - /* Softmax */ - flux_softmax(scores, spatial, spatial); + if (!used_cuda_attention) { + /* Scale Q before BLAS fallback. */ + int q_elems = spatial * ch; + for (int i = 0; i < q_elems; i++) q_t[i] *= scale; - /* scores @ V using BLAS: [HW, HW] @ [HW, C] -> [HW, C] */ - flux_matmul(o_t, scores, v_t, spatial, spatial, ch); + /* Q @ K^T using BLAS: [HW, C] @ [C, HW] -> [HW, HW] */ + flux_matmul_t(scores, q_t, k_t, spatial, ch, spatial); + + /* Softmax */ + flux_softmax(scores, spatial, spatial); + + /* scores @ V using BLAS: [HW, HW] @ [HW, C] -> [HW, C] */ + flux_matmul(o_t, scores, v_t, spatial, spatial, ch); + } /* Transpose output back [HW, C] -> [C, HW] */ for (int c = 0; c < ch; c++) { @@ -298,12 +352,6 @@ static int attnblock_forward(float *out, const float *x, } } - free(q_t); - free(k_t); - free(v_t); - free(o_t); - free(scores); - /* Project output */ vae_conv2d(work, attn_out, block->out_weight, block->out_bias, batch, ch, ch, H, W, 1, 1, 1, 0); @@ -382,7 +430,7 @@ float *flux_vae_encode(flux_vae_t *vae, const float *img, batch, cur_h, cur_w, vae->num_groups, vae->eps); if (flux_vae_progress_callback) flux_vae_progress_callback(progress++, total_blocks); - if (attnblock_forward(x, work, &vae->enc_mid_attn, vae->work3, + if (attnblock_forward(vae, x, work, &vae->enc_mid_attn, vae->work3, batch, cur_h, cur_w, vae->num_groups, vae->eps) < 0) { return NULL; /* OOM in attention */ } @@ -575,7 +623,7 @@ static flux_image *vae_decode_gpu(flux_vae_t *vae, const float *latent, /* Run attention on CPU (uses existing attnblock_forward) */ float *cpu_attn_out = cpu_x; - if (attnblock_forward(cpu_attn_out, cpu_attn_in, &vae->dec_mid_attn, + if (attnblock_forward(vae, cpu_attn_out, cpu_attn_in, &vae->dec_mid_attn, vae->work3, batch, cur_h, cur_w, vae->num_groups, vae->eps) < 0) { flux_gpu_tensor_free(x); @@ -683,6 +731,442 @@ static flux_image *vae_decode_gpu(flux_vae_t *vae, const float *latent, #endif /* USE_METAL */ +#ifdef USE_CUDA +typedef struct { + float *d_col; size_t col_cap; + float *d_rows; size_t rows_cap; + float *d_bias; size_t bias_cap; + float *d_gamma; size_t gamma_cap; + float *d_beta; size_t beta_cap; + float *d_tmp1; size_t tmp1_cap; + float *d_tmp2; size_t tmp2_cap; + float *d_q; size_t q_cap; + float *d_k; size_t k_cap; + float *d_v; size_t v_cap; + float *d_attn; size_t attn_cap; + float *d_q_rows; size_t q_rows_cap; + float *d_k_rows; size_t k_rows_cap; + float *d_v_rows; size_t v_rows_cap; + float *d_o_rows; size_t o_rows_cap; + size_t tile_bytes; +} vae_cuda_decode_scratch_t; + +static void vae_cuda_scratch_free(vae_cuda_decode_scratch_t *s) { + if (!s) return; + if (s->d_col) cudaFree(s->d_col); + if (s->d_rows) cudaFree(s->d_rows); + if (s->d_bias) cudaFree(s->d_bias); + if (s->d_gamma) cudaFree(s->d_gamma); + if (s->d_beta) cudaFree(s->d_beta); + if (s->d_tmp1) cudaFree(s->d_tmp1); + if (s->d_tmp2) cudaFree(s->d_tmp2); + if (s->d_q) cudaFree(s->d_q); + if (s->d_k) cudaFree(s->d_k); + if (s->d_v) cudaFree(s->d_v); + if (s->d_attn) cudaFree(s->d_attn); + if (s->d_q_rows) cudaFree(s->d_q_rows); + if (s->d_k_rows) cudaFree(s->d_k_rows); + if (s->d_v_rows) cudaFree(s->d_v_rows); + if (s->d_o_rows) cudaFree(s->d_o_rows); + memset(s, 0, sizeof(*s)); +} + +static int vae_cuda_ensure_f32(float **ptr, size_t *cap_elems, size_t need_elems) { + if (*cap_elems >= need_elems) return 1; + if (*ptr) cudaFree(*ptr); + *ptr = NULL; + *cap_elems = 0; + if (cudaMalloc((void **)ptr, need_elems * sizeof(float)) != cudaSuccess) return 0; + *cap_elems = need_elems; + return 1; +} + +static int vae_cuda_group_norm(float *d_out, const float *d_x, + const float *gamma, const float *beta, + int batch, int channels, int H, int W, + int num_groups, float eps, + vae_cuda_decode_scratch_t *s) { + if (!vae_cuda_ensure_f32(&s->d_gamma, &s->gamma_cap, channels)) return 0; + if (!vae_cuda_ensure_f32(&s->d_beta, &s->beta_cap, channels)) return 0; + if (cudaMemcpy(s->d_gamma, gamma, (size_t)channels * sizeof(float), + cudaMemcpyHostToDevice) != cudaSuccess) return 0; + if (cudaMemcpy(s->d_beta, beta, (size_t)channels * sizeof(float), + cudaMemcpyHostToDevice) != cudaSuccess) return 0; + return flux_cuda_group_norm_nchw_device(d_out, d_x, s->d_gamma, s->d_beta, + batch, channels, H, W, num_groups, eps); +} + +static int vae_conv2d_cuda_tiled(float *d_out, const float *d_in, + const float *weight, const float *bias, + int in_ch, int out_ch, int H, int W, + int kH, int kW, int stride, int padding, + vae_cuda_decode_scratch_t *s) { + int outH = (H + 2 * padding - kH) / stride + 1; + int outW = (W + 2 * padding - kW) / stride + 1; + int K = in_ch * kH * kW; + if (outH <= 0 || outW <= 0 || K <= 0) return 0; + + size_t limit = s->tile_bytes ? s->tile_bytes : (size_t)64 * 1024 * 1024; + size_t by_col = limit / (sizeof(float) * (size_t)K); + size_t by_rows = limit / (sizeof(float) * (size_t)out_ch); + size_t tile_pixels_max = by_col < by_rows ? by_col : by_rows; + if (tile_pixels_max < (size_t)outW) tile_pixels_max = (size_t)outW; + int tile_rows = (int)(tile_pixels_max / (size_t)outW); + if (tile_rows < 1) tile_rows = 1; + if (tile_rows > outH) tile_rows = outH; + + if (bias) { + if (!vae_cuda_ensure_f32(&s->d_bias, &s->bias_cap, out_ch)) return 0; + if (cudaMemcpy(s->d_bias, bias, (size_t)out_ch * sizeof(float), + cudaMemcpyHostToDevice) != cudaSuccess) return 0; + } + + for (int row = 0; row < outH; row += tile_rows) { + int tile_h = tile_rows; + if (row + tile_h > outH) tile_h = outH - row; + int tile_pixels = tile_h * outW; + size_t col_elems = (size_t)tile_pixels * K; + size_t row_elems = (size_t)tile_pixels * out_ch; + + if (!vae_cuda_ensure_f32(&s->d_col, &s->col_cap, col_elems)) return 0; + if (!vae_cuda_ensure_f32(&s->d_rows, &s->rows_cap, row_elems)) return 0; + + if (!flux_cuda_im2col_nchw_rows_device(s->d_col, d_in, + in_ch, H, W, kH, kW, stride, padding, + outH, outW, row, tile_h)) return 0; + + if (!flux_cuda_linear_nobias_device(s->d_rows, s->d_col, weight, + tile_pixels, K, out_ch)) return 0; + + if (bias && !flux_cuda_add_bias_rows_device(s->d_rows, s->d_bias, + tile_pixels, out_ch)) return 0; + + if (!flux_cuda_rows_to_nchw_tile_device(d_out, s->d_rows, + out_ch, outH, outW, row, tile_h)) return 0; + } + + return 1; +} + +static float *vae_resblock_forward_cuda(const float *d_x, const vae_resblock_t *block, + int batch, int H, int W, + int num_groups, float eps, + vae_cuda_decode_scratch_t *s) { + int in_ch = block->in_channels; + int out_ch = block->out_channels; + int spatial = H * W; + size_t in_elems = (size_t)batch * in_ch * spatial; + size_t out_elems = (size_t)batch * out_ch * spatial; + + if (!vae_cuda_ensure_f32(&s->d_tmp1, &s->tmp1_cap, out_elems > in_elems ? out_elems : in_elems)) + return NULL; + if (!vae_cuda_ensure_f32(&s->d_tmp2, &s->tmp2_cap, out_elems)) + return NULL; + + float *d_skip = NULL; + if (cudaMalloc((void **)&d_skip, out_elems * sizeof(float)) != cudaSuccess) return NULL; + + if (in_ch != out_ch) { + if (!vae_conv2d_cuda_tiled(d_skip, d_x, block->skip_weight, block->skip_bias, + in_ch, out_ch, H, W, 1, 1, 1, 0, s)) { + cudaFree(d_skip); + return NULL; + } + } else { + if (cudaMemcpy(d_skip, d_x, out_elems * sizeof(float), + cudaMemcpyDeviceToDevice) != cudaSuccess) { + cudaFree(d_skip); + return NULL; + } + } + + if (!vae_cuda_group_norm(s->d_tmp1, d_x, block->norm1_weight, block->norm1_bias, + batch, in_ch, H, W, num_groups, eps, s)) { + cudaFree(d_skip); + return NULL; + } + if (!flux_cuda_silu_device(s->d_tmp1, (int)in_elems)) { + cudaFree(d_skip); + return NULL; + } + if (!vae_conv2d_cuda_tiled(s->d_tmp2, s->d_tmp1, block->conv1_weight, block->conv1_bias, + in_ch, out_ch, H, W, 3, 3, 1, 1, s)) { + cudaFree(d_skip); + return NULL; + } + if (!vae_cuda_group_norm(s->d_tmp1, s->d_tmp2, block->norm2_weight, block->norm2_bias, + batch, out_ch, H, W, num_groups, eps, s)) { + cudaFree(d_skip); + return NULL; + } + if (!flux_cuda_silu_device(s->d_tmp1, (int)out_elems)) { + cudaFree(d_skip); + return NULL; + } + if (!vae_conv2d_cuda_tiled(s->d_tmp2, s->d_tmp1, block->conv2_weight, block->conv2_bias, + out_ch, out_ch, H, W, 3, 3, 1, 1, s)) { + cudaFree(d_skip); + return NULL; + } + if (!flux_cuda_add_inplace_device(d_skip, s->d_tmp2, (int)out_elems)) { + cudaFree(d_skip); + return NULL; + } + + return d_skip; +} + +static float *vae_attnblock_forward_cuda(const float *d_x, const vae_attnblock_t *block, + int batch, int H, int W, + int num_groups, float eps, + vae_cuda_decode_scratch_t *s) { + if (batch != 1) return NULL; /* Current CUDA attention path handles batch=1 decode. */ + + int ch = block->channels; + int spatial = H * W; + size_t n = (size_t)ch * spatial; + + if (!vae_cuda_ensure_f32(&s->d_tmp1, &s->tmp1_cap, n)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_tmp2, &s->tmp2_cap, n)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_q, &s->q_cap, n)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_k, &s->k_cap, n)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_v, &s->v_cap, n)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_attn, &s->attn_cap, n)) return NULL; + + size_t rows_elems = (size_t)spatial * ch; + if (!vae_cuda_ensure_f32(&s->d_q_rows, &s->q_rows_cap, rows_elems)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_k_rows, &s->k_rows_cap, rows_elems)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_v_rows, &s->v_rows_cap, rows_elems)) return NULL; + if (!vae_cuda_ensure_f32(&s->d_o_rows, &s->o_rows_cap, rows_elems)) return NULL; + + if (!vae_cuda_group_norm(s->d_tmp1, d_x, block->norm_weight, block->norm_bias, + batch, ch, H, W, num_groups, eps, s)) return NULL; + + if (!vae_conv2d_cuda_tiled(s->d_q, s->d_tmp1, block->q_weight, block->q_bias, + ch, ch, H, W, 1, 1, 1, 0, s)) return NULL; + if (!vae_conv2d_cuda_tiled(s->d_k, s->d_tmp1, block->k_weight, block->k_bias, + ch, ch, H, W, 1, 1, 1, 0, s)) return NULL; + if (!vae_conv2d_cuda_tiled(s->d_v, s->d_tmp1, block->v_weight, block->v_bias, + ch, ch, H, W, 1, 1, 1, 0, s)) return NULL; + + if (!flux_cuda_nchw_to_rows_device(s->d_q_rows, s->d_q, ch, H, W)) return NULL; + if (!flux_cuda_nchw_to_rows_device(s->d_k_rows, s->d_k, ch, H, W)) return NULL; + if (!flux_cuda_nchw_to_rows_device(s->d_v_rows, s->d_v, ch, H, W)) return NULL; + + float scale = 1.0f / sqrtf((float)ch); + if (!flux_cuda_attention_batched_shd_device(s->d_o_rows, s->d_q_rows, s->d_k_rows, s->d_v_rows, + 1, spatial, spatial, ch, scale, 0, NULL)) { + return NULL; + } + if (!flux_cuda_rows_to_nchw_device(s->d_attn, s->d_o_rows, ch, H, W)) return NULL; + + if (!vae_conv2d_cuda_tiled(s->d_tmp2, s->d_attn, block->out_weight, block->out_bias, + ch, ch, H, W, 1, 1, 1, 0, s)) return NULL; + + float *d_out = NULL; + if (cudaMalloc((void **)&d_out, n * sizeof(float)) != cudaSuccess) return NULL; + if (cudaMemcpy(d_out, d_x, n * sizeof(float), cudaMemcpyDeviceToDevice) != cudaSuccess) { + cudaFree(d_out); + return NULL; + } + if (!flux_cuda_add_inplace_device(d_out, s->d_tmp2, (int)n)) { + cudaFree(d_out); + return NULL; + } + + return d_out; +} + +static flux_image *vae_decode_cuda(flux_vae_t *vae, const float *latent, + int batch, int latent_h, int latent_w) { + if (!latent || !vae) return NULL; + if (batch != 1) return NULL; + + int ch_mult[4] = {1, 2, 4, 4}; + int z_spatial = latent_h * latent_w; + float *cpu_x = vae->work1; + float *cpu_work = vae->work2; + + /* CPU pre-processing: denorm + unpatchify (small tensors). */ + flux_copy(cpu_x, latent, batch * FLUX_LATENT_CHANNELS * z_spatial); + for (int b = 0; b < batch; b++) { + for (int c = 0; c < FLUX_LATENT_CHANNELS; c++) { + float mean = vae->bn_mean[c]; + float std = sqrtf(vae->bn_var[c] + vae->eps); + for (int i = 0; i < z_spatial; i++) { + int idx = b * FLUX_LATENT_CHANNELS * z_spatial + c * z_spatial + i; + cpu_x[idx] = cpu_x[idx] * std + mean; + } + } + } + + int cur_h = latent_h * 2; + int cur_w = latent_w * 2; + flux_unpatchify(cpu_work, cpu_x, batch, vae->z_channels, latent_h, latent_w, 2); + flux_copy(cpu_x, cpu_work, batch * vae->z_channels * cur_h * cur_w); + + float *d_x = NULL; + size_t x_elems = (size_t)batch * vae->z_channels * cur_h * cur_w; + if (cudaMalloc((void **)&d_x, x_elems * sizeof(float)) != cudaSuccess) return NULL; + if (cudaMemcpy(d_x, cpu_x, x_elems * sizeof(float), cudaMemcpyHostToDevice) != cudaSuccess) { + cudaFree(d_x); + return NULL; + } + + vae_cuda_decode_scratch_t s; + memset(&s, 0, sizeof(s)); + s.tile_bytes = (size_t)64 * 1024 * 1024; + const char *tile_env = getenv("FLUX_CUDA_CONV_TILE_MB"); + if (tile_env) { + long mb = strtol(tile_env, NULL, 10); + if (mb >= 4) s.tile_bytes = (size_t)mb * 1024 * 1024; + } + + int progress = 0; + int total_blocks = 3 + 4 * (vae->num_res_blocks + 1); + float *d_next = NULL; + +#define CUDA_DECODE_FAIL do { if (d_x) cudaFree(d_x); if (d_next) cudaFree(d_next); vae_cuda_scratch_free(&s); return NULL; } while (0) + + /* Post-quantization conv (1x1): 32 -> 32 */ + x_elems = (size_t)batch * vae->z_channels * cur_h * cur_w; + if (cudaMalloc((void **)&d_next, x_elems * sizeof(float)) != cudaSuccess) CUDA_DECODE_FAIL; + if (!vae_conv2d_cuda_tiled(d_next, d_x, vae->post_quant_conv_weight, vae->post_quant_conv_bias, + vae->z_channels, vae->z_channels, cur_h, cur_w, 1, 1, 1, 0, &s)) + CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + + /* Conv in: 32 -> 512 */ + int mid_ch = vae->base_channels * ch_mult[3]; + x_elems = (size_t)batch * mid_ch * cur_h * cur_w; + if (cudaMalloc((void **)&d_next, x_elems * sizeof(float)) != cudaSuccess) CUDA_DECODE_FAIL; + if (!vae_conv2d_cuda_tiled(d_next, d_x, vae->dec_conv_in_weight, vae->dec_conv_in_bias, + vae->z_channels, mid_ch, cur_h, cur_w, 3, 3, 1, 1, &s)) + CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + + /* Mid block: resblock -> attn -> resblock */ + d_next = vae_resblock_forward_cuda(d_x, &vae->dec_mid_block1, batch, cur_h, cur_w, + vae->num_groups, vae->eps, &s); + if (!d_next) CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + if (flux_vae_progress_callback) flux_vae_progress_callback(progress++, total_blocks); + + d_next = vae_attnblock_forward_cuda(d_x, &vae->dec_mid_attn, batch, cur_h, cur_w, + vae->num_groups, vae->eps, &s); + if (!d_next) CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + if (flux_vae_progress_callback) flux_vae_progress_callback(progress++, total_blocks); + + d_next = vae_resblock_forward_cuda(d_x, &vae->dec_mid_block2, batch, cur_h, cur_w, + vae->num_groups, vae->eps, &s); + if (!d_next) CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + if (flux_vae_progress_callback) flux_vae_progress_callback(progress++, total_blocks); + + int block_idx = 0; + int up_idx = 0; + for (int level = 3; level >= 0; level--) { + int ch_out = vae->base_channels * ch_mult[level]; + + for (int r = 0; r < vae->num_res_blocks + 1; r++) { + vae_resblock_t *block = &vae->dec_up_blocks[block_idx++]; + d_next = vae_resblock_forward_cuda(d_x, block, batch, cur_h, cur_w, + vae->num_groups, vae->eps, &s); + if (!d_next) CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + if (flux_vae_progress_callback) flux_vae_progress_callback(progress++, total_blocks); + } + + if (level > 0) { + vae_upsample_t *us = &vae->dec_upsample[up_idx++]; + int new_h = cur_h * 2; + int new_w = cur_w * 2; + + size_t up_elems = (size_t)batch * ch_out * new_h * new_w; + if (cudaMalloc((void **)&d_next, up_elems * sizeof(float)) != cudaSuccess) CUDA_DECODE_FAIL; + if (!flux_cuda_upsample_nearest2x_nchw_device(d_next, d_x, ch_out, cur_h, cur_w)) + CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + + if (cudaMalloc((void **)&d_next, up_elems * sizeof(float)) != cudaSuccess) CUDA_DECODE_FAIL; + if (!vae_conv2d_cuda_tiled(d_next, d_x, us->conv_weight, us->conv_bias, + ch_out, ch_out, new_h, new_w, 3, 3, 1, 1, &s)) + CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + + cur_h = new_h; + cur_w = new_w; + } + } + + /* Output: norm -> swish -> conv_out */ + size_t final_elems = (size_t)batch * vae->base_channels * cur_h * cur_w; + if (!vae_cuda_ensure_f32(&s.d_tmp1, &s.tmp1_cap, final_elems)) CUDA_DECODE_FAIL; + if (!vae_cuda_group_norm(s.d_tmp1, d_x, vae->dec_norm_out_weight, vae->dec_norm_out_bias, + batch, vae->base_channels, cur_h, cur_w, + vae->num_groups, vae->eps, &s)) CUDA_DECODE_FAIL; + if (!flux_cuda_silu_device(s.d_tmp1, (int)final_elems)) CUDA_DECODE_FAIL; + + size_t rgb_elems = (size_t)batch * 3 * cur_h * cur_w; + if (cudaMalloc((void **)&d_next, rgb_elems * sizeof(float)) != cudaSuccess) CUDA_DECODE_FAIL; + if (!vae_conv2d_cuda_tiled(d_next, s.d_tmp1, vae->dec_conv_out_weight, vae->dec_conv_out_bias, + vae->base_channels, 3, cur_h, cur_w, 3, 3, 1, 1, &s)) + CUDA_DECODE_FAIL; + cudaFree(d_x); + d_x = d_next; + d_next = NULL; + + float *rgb = (float *)malloc(rgb_elems * sizeof(float)); + if (!rgb) CUDA_DECODE_FAIL; + if (cudaMemcpy(rgb, d_x, rgb_elems * sizeof(float), cudaMemcpyDeviceToHost) != cudaSuccess) { + free(rgb); + CUDA_DECODE_FAIL; + } + + flux_image *img = flux_image_create(cur_w, cur_h, 3); + if (!img) { + free(rgb); + CUDA_DECODE_FAIL; + } + + for (int y = 0; y < cur_h; y++) { + for (int x = 0; x < cur_w; x++) { + for (int c = 0; c < 3; c++) { + float val = rgb[c * cur_h * cur_w + y * cur_w + x]; + val = (val + 1.0f) * 0.5f * 255.0f; + if (val < 0.0f) val = 0.0f; + if (val > 255.0f) val = 255.0f; + img->data[(y * cur_w + x) * 3 + c] = (unsigned char)(val + 0.5f); + } + } + } + + free(rgb); + cudaFree(d_x); + vae_cuda_scratch_free(&s); + return img; + +#undef CUDA_DECODE_FAIL +} +#endif /* USE_CUDA */ + /* ======================================================================== * Decoder Forward Pass * ======================================================================== */ @@ -698,6 +1182,14 @@ flux_image *flux_vae_decode(flux_vae_t *vae, const float *latent, } #endif +#ifdef USE_CUDA + if (!getenv("FLUX_CUDA_NO_VAE_RESIDENT")) { + flux_image *gpu_result = vae_decode_cuda(vae, latent, batch, latent_h, latent_w); + if (gpu_result) return gpu_result; + /* Fall through to CPU path on failure */ + } +#endif + /* * Decoder path: * [B, 128, H/16, W/16] @@ -754,7 +1246,7 @@ flux_image *flux_vae_decode(flux_vae_t *vae, const float *latent, batch, cur_h, cur_w, vae->num_groups, vae->eps); if (flux_vae_progress_callback) flux_vae_progress_callback(progress++, total_blocks); - if (attnblock_forward(x, work, &vae->dec_mid_attn, vae->work3, + if (attnblock_forward(vae, x, work, &vae->dec_mid_attn, vae->work3, batch, cur_h, cur_w, vae->num_groups, vae->eps) < 0) { return NULL; /* OOM in attention */ } @@ -1098,6 +1590,11 @@ void flux_vae_free(flux_vae_t *vae) { free(vae->work1); free(vae->work2); free(vae->work3); + free(vae->attn_q_t); + free(vae->attn_k_t); + free(vae->attn_v_t); + free(vae->attn_o_t); + free(vae->attn_scores); free(vae); } diff --git a/main.c b/main.c index b915585..c8f91f3 100644 --- a/main.c +++ b/main.c @@ -30,25 +30,11 @@ #include #include #include -#include #ifdef USE_METAL #include "flux_metal.h" #endif -#ifdef USE_BLAS -#ifdef __APPLE__ -#include -#else -/* OpenBLAS introspection functions */ -extern int openblas_get_num_threads(void); -extern int openblas_get_num_procs(void); -extern char *openblas_get_corename(void); -extern char *openblas_get_config(void); -extern void openblas_set_num_threads(int num_threads); -#endif -#endif - /* ======================================================================== * Verbosity Levels * ======================================================================== */ @@ -216,7 +202,7 @@ static double timer_end(void) { #define MAX_INPUT_IMAGES 16 static void print_usage(const char *prog) { - fprintf(stderr, "FLUX.2 klein - Pure C Image Generation\n\n"); + fprintf(stderr, "FLUX.2 klein 4B - Pure C Image Generation\n\n"); fprintf(stderr, "Usage: %s [options]\n\n", prog); fprintf(stderr, "Required:\n"); fprintf(stderr, " -d, --dir PATH Path to model directory\n"); @@ -239,15 +225,15 @@ static void print_usage(const char *prog) { fprintf(stderr, "Output options:\n"); fprintf(stderr, " -q, --quiet Silent mode, no output\n"); fprintf(stderr, " -v, --verbose Detailed output\n"); - fprintf(stderr, " --show Display image in terminal (auto-detects Kitty/Ghostty/iTerm2/WezTerm/Konsole)\n"); + fprintf(stderr, " --show Display image in terminal (auto-detects Kitty/Ghostty/iTerm2/Konsole)\n"); fprintf(stderr, " --show-steps Display each denoising step (slower)\n"); fprintf(stderr, " --zoom N Terminal image zoom factor (default: 2 for Retina)\n\n"); fprintf(stderr, "Other options:\n"); fprintf(stderr, " -e, --embeddings PATH Load pre-computed text embeddings\n"); fprintf(stderr, " -m, --mmap Use memory-mapped weights (default, fastest on MPS)\n"); fprintf(stderr, " --no-mmap Disable mmap, load all weights upfront\n"); - fprintf(stderr, " --no-license-info Suppress non-commercial license warning\n"); - fprintf(stderr, " --blas-threads N Set number of BLAS threads (OpenBLAS only)\n"); + fprintf(stderr, " -R, --server Persistent stdin server mode (one prompt per line)\n"); + fprintf(stderr, " --overlap-preload Overlap transformer load with text encode (optional)\n"); fprintf(stderr, " -h, --help Show this help\n\n"); fprintf(stderr, "Examples:\n"); fprintf(stderr, " %s -d model/ -p \"a cat on a rainbow\" -o cat.png\n", prog); @@ -255,6 +241,99 @@ static void print_usage(const char *prog) { fprintf(stderr, " %s -d model/ -p \"combine them\" -i car.png -i beach.png -o result.png\n", prog); } +static void server_output_path(char *dst, size_t dst_size, + const char *pattern, int index) { + if (!pattern || !pattern[0]) { + snprintf(dst, dst_size, "server-%04d.png", index); + return; + } + if (strstr(pattern, "%d")) { + snprintf(dst, dst_size, pattern, index); + return; + } + + const char *dot = strrchr(pattern, '.'); + if (dot && dot != pattern) { + int base_len = (int)(dot - pattern); + snprintf(dst, dst_size, "%.*s-%04d%s", base_len, pattern, index, dot); + } else { + snprintf(dst, dst_size, "%s-%04d.png", pattern, index); + } +} + +static int run_server_mode(flux_ctx *ctx, + const flux_params *base_params, + const char *output_pattern, + int show_image, + term_graphics_proto graphics_proto) { + char *line = NULL; + size_t line_cap = 0; + int req_id = 0; + + if (output_level >= OUTPUT_NORMAL) { + fprintf(stderr, "Server mode ready. Enter one prompt per line.\n"); + fprintf(stderr, "Use TAB to specify output path: prompt/tmp/out.png\n"); + fprintf(stderr, "Type /quit to exit.\n"); + } + + while (getline(&line, &line_cap, stdin) != -1) { + size_t len = strlen(line); + while (len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r')) { + line[--len] = '\0'; + } + if (len == 0) continue; + if (strcmp(line, "/quit") == 0 || strcmp(line, "quit") == 0 || strcmp(line, "exit") == 0) { + break; + } + + req_id++; + char *tab = strchr(line, '\t'); + char *prompt = line; + char out_path[4096]; + if (tab) { + *tab = '\0'; + if (strstr(tab + 1, "%d")) { + snprintf(out_path, sizeof(out_path), tab + 1, req_id); + } else { + snprintf(out_path, sizeof(out_path), "%s", tab + 1); + } + } else { + server_output_path(out_path, sizeof(out_path), output_pattern, req_id); + } + + flux_params p = *base_params; + if (p.seed >= 0) p.seed += req_id - 1; + + struct timeval t0, t1; + gettimeofday(&t0, NULL); + flux_image *img = flux_generate(ctx, prompt, &p); + gettimeofday(&t1, NULL); + double elapsed = (t1.tv_sec - t0.tv_sec) + (t1.tv_usec - t0.tv_usec) / 1000000.0; + + if (!img) { + fprintf(stdout, "ERR %d %s\n", req_id, flux_get_error()); + fflush(stdout); + continue; + } + + int64_t out_seed = (p.seed >= 0) ? p.seed : (int64_t)time(NULL); + if (flux_image_save_with_seed(img, out_path, out_seed) != 0) { + fprintf(stdout, "ERR %d save_failed\n", req_id); + fflush(stdout); + flux_image_free(img); + continue; + } + + if (show_image) terminal_display_png(out_path, graphics_proto); + fprintf(stdout, "OK %d %s %.3f\n", req_id, out_path, elapsed); + fflush(stdout); + flux_image_free(img); + } + + free(line); + return 0; +} + /* ======================================================================== * Main * ======================================================================== */ @@ -262,6 +341,12 @@ static void print_usage(const char *prog) { int main(int argc, char *argv[]) { #ifdef USE_METAL flux_metal_init(); +#elif defined(USE_CUDA) + fprintf(stderr, "CUDA: cuBLAS GPU acceleration enabled\n"); +#elif defined(USE_BLAS) + fprintf(stderr, "BLAS: CPU acceleration enabled (Accelerate/OpenBLAS)\n"); +#else + fprintf(stderr, "Generic: Pure C backend (no acceleration)\n"); #endif /* Command line options */ @@ -290,9 +375,9 @@ int main(int argc, char *argv[]) { {"linear", no_argument, 0, 'L'}, {"power", no_argument, 0, 256}, {"power-alpha",required_argument, 0, 257}, + {"overlap-preload", no_argument, 0, 258}, + {"server", no_argument, 0, 'R'}, {"debug-py", no_argument, 0, 'D'}, - {"no-license-info", no_argument, 0, 258}, - {"blas-threads",required_argument, 0, 259}, {0, 0, 0, 0} }; @@ -320,12 +405,12 @@ int main(int argc, char *argv[]) { int show_steps = 0; int debug_py = 0; int force_base = 0; - int no_license_info = 0; - int blas_threads = 0; (void)blas_threads; + int server_mode = 0; + int overlap_preload = 0; term_graphics_proto graphics_proto = detect_terminal_graphics(); int opt; - while ((opt = getopt_long(argc, argv, "d:p:o:W:H:s:g:S:i:t:e:n:qvhVmMD", + while ((opt = getopt_long(argc, argv, "d:p:o:W:H:s:g:S:i:t:e:n:qvhVmMDR", long_options, NULL)) != -1) { switch (opt) { case 'd': model_dir = optarg; break; @@ -349,7 +434,7 @@ int main(int argc, char *argv[]) { case 'v': output_level = OUTPUT_VERBOSE; flux_verbose = 1; break; case 'h': print_usage(argv[0]); return 0; case 'V': - fprintf(stderr, "FLUX.2 klein v1.0.0\n"); + fprintf(stderr, "FLUX.2 klein 4B v1.0.0\n"); return 0; case 'm': use_mmap = 1; break; case 'M': use_mmap = 0; break; @@ -360,53 +445,15 @@ int main(int argc, char *argv[]) { case 'L': params.linear_schedule = 1; break; case 256: params.power_schedule = 1; break; case 257: params.power_alpha = atof(optarg); params.power_schedule = 1; break; - case 258: no_license_info = 1; break; + case 258: overlap_preload = 1; break; + case 'R': server_mode = 1; break; case 'D': debug_py = 1; break; - case 259: blas_threads = atoi(optarg); break; default: print_usage(argv[0]); return 1; } } - /* BLAS: apply thread setting regardless of quiet mode */ -#if defined(USE_BLAS) && !defined(USE_METAL) && !defined(__APPLE__) - if (blas_threads > 0) openblas_set_num_threads(blas_threads); -#endif - - /* Backend banner (suppressed by --quiet) */ - if (output_level != OUTPUT_QUIET) { -#ifdef USE_METAL - if (flux_metal_available()) { - long ncpu = sysconf(_SC_NPROCESSORS_ONLN); - char cpu_brand[128] = "Apple Silicon"; - size_t len = sizeof(cpu_brand); - sysctlbyname("machdep.cpu.brand_string", cpu_brand, &len, NULL, 0); - fprintf(stderr, "MPS: Metal GPU | %s | %ld cores\n", cpu_brand, ncpu); - } -#elif defined(USE_BLAS) -#ifdef __APPLE__ - { - char cpu_brand[128] = "Apple Silicon"; - size_t len = sizeof(cpu_brand); - sysctlbyname("machdep.cpu.brand_string", cpu_brand, &len, NULL, 0); - long ncpu = sysconf(_SC_NPROCESSORS_ONLN); - fprintf(stderr, "BLAS: Accelerate | %s | %ld cores\n", cpu_brand, ncpu); - if (blas_threads > 0) - fprintf(stderr, "Warning: --blas-threads ignored (Accelerate manages threading automatically)\n"); - } -#else - fprintf(stderr, "BLAS: OpenBLAS | %s | %d threads / %d procs\n", - openblas_get_corename(), - openblas_get_num_threads(), - openblas_get_num_procs()); - fprintf(stderr, " %s\n", openblas_get_config()); -#endif -#else - fprintf(stderr, "Generic: Pure C backend (no acceleration)\n"); -#endif - } - /* Validate required arguments */ if (!model_dir) { fprintf(stderr, "Error: Model directory (-d) is required\n\n"); @@ -414,10 +461,10 @@ int main(int argc, char *argv[]) { return 1; } - /* Interactive mode: -d provided but no -p, -e, -o, or --debug-py */ - int interactive_mode = (!prompt && !embeddings_path && !output_path && !debug_py); + /* Interactive mode: -d provided but no request options and not server mode. */ + int interactive_mode = (!server_mode && !prompt && !embeddings_path && !output_path && !debug_py); - if (!interactive_mode) { + if (!interactive_mode && !server_mode) { if (!prompt && !embeddings_path && !debug_py) { fprintf(stderr, "Error: Prompt (-p) or embeddings file (-e) is required\n\n"); print_usage(argv[0]); @@ -430,6 +477,10 @@ int main(int argc, char *argv[]) { } } + if (overlap_preload) { + setenv("FLUX_OVERLAP_PRELOAD", "1", 1); + } + /* Validate parameters */ if (params.width < 64 || params.width > 4096) { fprintf(stderr, "Error: Width must be between 64 and 4096\n"); @@ -455,11 +506,12 @@ int main(int argc, char *argv[]) { LOG_NORMAL("Seed: %lld\n", (long long)actual_seed); /* Verbose header */ - LOG_VERBOSE("FLUX.2 klein Image Generator\n"); + LOG_VERBOSE("FLUX.2 klein 4B Image Generator\n"); LOG_VERBOSE("================================\n"); LOG_VERBOSE("Model: %s\n", model_dir); if (prompt) LOG_VERBOSE("Prompt: %s\n", prompt); - LOG_VERBOSE("Output: %s\n", output_path); + if (output_path) LOG_VERBOSE("Output: %s\n", output_path); + else if (server_mode) LOG_VERBOSE("Output: server pattern (auto)\n"); LOG_VERBOSE("Size: %dx%d\n", params.width, params.height); LOG_VERBOSE("Steps: %d\n", params.num_steps); if (num_inputs > 0) { @@ -504,17 +556,6 @@ int main(int argc, char *argv[]) { LOG_NORMAL(" done (%.1fs)\n", load_time); LOG_NORMAL("Model: %s\n", flux_model_info(ctx)); - /* Non-commercial license warning for 9B model */ - if (flux_is_non_commercial(ctx) && !no_license_info - && output_level != OUTPUT_QUIET) { - fprintf(stderr, - "\nNOTE: This model is released under a NON COMMERCIAL LICENSE.\n" - "The output can only be used under the terms of the\n" - "FLUX non-commercial license:\n" - "https://huggingface.co/black-forest-labs/FLUX.2-klein-9B/blob/main/LICENSE.md\n" - "(use --no-license-info to suppress this message)\n\n"); - } - /* Interactive mode: start REPL */ if (interactive_mode) { int rc = flux_cli_run(ctx, model_dir); @@ -530,13 +571,25 @@ int main(int argc, char *argv[]) { /* Set up step image callback if requested */ if (show_steps) { if (graphics_proto == TERM_PROTO_NONE) { - fprintf(stderr, "Warning: --show-steps requires a supported terminal (Kitty, Ghostty, iTerm2, WezTerm, or Konsole)\n"); + fprintf(stderr, "Warning: --show-steps requires a supported terminal (Kitty, Ghostty, iTerm2, or Konsole)\n"); } else { cli_graphics_proto = graphics_proto; flux_set_step_image_callback(ctx, cli_step_image_callback); } } + /* Persistent stdin server mode */ + if (server_mode) { + int rc = run_server_mode(ctx, ¶ms, output_path, show_image, graphics_proto); + cli_finish_progress(); + if (show_steps) flux_set_step_image_callback(ctx, NULL); + flux_free(ctx); +#ifdef USE_METAL + flux_metal_cleanup(); +#endif + return rc; + } + /* Generate image */ flux_image *output = NULL; struct timeval total_start_tv; From e87d9a4466cfd9622bd7af1584cf5202f5a8ede9 Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Mon, 9 Feb 2026 09:52:15 -0500 Subject: [PATCH 2/2] keep model loaded in server --- flux.c | 54 +++++++++++++++++++++++++++++++++++++----------------- flux.h | 7 +++++++ main.c | 6 ++++++ 3 files changed, 50 insertions(+), 17 deletions(-) diff --git a/flux.c b/flux.c index 351227a..d6fc9cb 100644 --- a/flux.c +++ b/flux.c @@ -155,6 +155,7 @@ struct flux_ctx { /* Memory mode */ int use_mmap; /* Use mmap for text encoder (lower memory, slower) */ + int keep_text_encoder_loaded; /* Keep Qwen resident across generations */ }; /* Global error message */ @@ -314,6 +315,10 @@ void flux_set_mmap(flux_ctx *ctx, int enable) { if (ctx) ctx->use_mmap = enable; } +void flux_set_keep_text_encoder(flux_ctx *ctx, int enable) { + if (ctx) ctx->keep_text_encoder_loaded = enable ? 1 : 0; +} + int flux_is_distilled(flux_ctx *ctx) { return ctx ? ctx->is_distilled : 1; } @@ -374,6 +379,32 @@ static int flux_load_transformer_if_needed(flux_ctx *ctx) { return flux_load_transformer_internal(ctx, 1); } +/* Load transformer with server-friendly fallback: + * keep text encoder resident when requested, but retry after releasing it + * if peak memory is too high during first transformer load. */ +static int flux_ensure_transformer_ready(flux_ctx *ctx) { + if (ctx->transformer) return 1; + + if (!ctx->keep_text_encoder_loaded && ctx->qwen3_encoder) { + flux_release_text_encoder(ctx); + return flux_load_transformer_if_needed(ctx); + } + + if (flux_load_transformer_if_needed(ctx)) { + return 1; + } + + if (ctx->keep_text_encoder_loaded && ctx->qwen3_encoder) { + fprintf(stderr, + "Warning: Transformer load failed with persistent text encoder; " + "retrying after releasing text encoder\n"); + flux_release_text_encoder(ctx); + return flux_load_transformer_if_needed(ctx); + } + + return 0; +} + #if defined(__unix__) || defined(__APPLE__) typedef struct { flux_ctx *ctx; @@ -527,12 +558,8 @@ flux_image *flux_generate(flux_ctx *ctx, const char *prompt, /* Ensure any async preload attempt is complete before load checks/fallback. */ flux_transformer_preload_join(&tf_preload); - /* Release text encoder only before first transformer load. - * Once transformer is loaded, keeping encoder avoids reload cost on later calls. */ - if (!ctx->transformer) flux_release_text_encoder(ctx); - - /* Load transformer now (after text encoder is freed to reduce peak memory) */ - if (!flux_load_transformer_if_needed(ctx)) { + /* Ensure transformer is ready (with optional low-memory fallback). */ + if (!flux_ensure_transformer_ready(ctx)) { free(text_emb); free(text_emb_uncond); return NULL; @@ -936,12 +963,8 @@ flux_image *flux_img2img(flux_ctx *ctx, const char *prompt, flux_transformer_preload_join(&tf_preload); - /* Release text encoder only before first transformer load. - * Once transformer is loaded, keeping encoder avoids reload cost on later calls. */ - if (!ctx->transformer) flux_release_text_encoder(ctx); - - /* Load transformer now (after text encoder is freed to reduce peak memory) */ - if (!flux_load_transformer_if_needed(ctx)) { + /* Ensure transformer is ready (with optional low-memory fallback). */ + if (!flux_ensure_transformer_ready(ctx)) { free(text_emb); free(text_emb_uncond); if (resized) flux_image_free(resized); @@ -1126,11 +1149,8 @@ flux_image *flux_multiref(flux_ctx *ctx, const char *prompt, flux_transformer_preload_join(&tf_preload); - /* Release text encoder only before first transformer load. - * Once transformer is loaded, keeping encoder avoids reload cost on later calls. */ - if (!ctx->transformer) flux_release_text_encoder(ctx); - - if (!flux_load_transformer_if_needed(ctx)) { + /* Ensure transformer is ready (with optional low-memory fallback). */ + if (!flux_ensure_transformer_ready(ctx)) { free(text_emb); free(text_emb_uncond); return NULL; diff --git a/flux.h b/flux.h index 6fb9213..fd3ef0b 100644 --- a/flux.h +++ b/flux.h @@ -120,6 +120,13 @@ void flux_release_text_encoder(flux_ctx *ctx); */ void flux_set_mmap(flux_ctx *ctx, int enable); +/* + * Keep the text encoder loaded across generations. + * Useful for persistent server mode to avoid repeated reload cost. + * Default is disabled to minimize peak memory during first transformer load. + */ +void flux_set_keep_text_encoder(flux_ctx *ctx, int enable); + /* * Check if model is distilled (4-step) or base (50-step with CFG). * Returns 1 for distilled, 0 for base. diff --git a/main.c b/main.c index c8f91f3..5498b4e 100644 --- a/main.c +++ b/main.c @@ -544,6 +544,12 @@ int main(int argc, char *argv[]) { flux_set_base_mode(ctx); } + /* In persistent stdin server mode, keep text encoder resident so + * per-request latency does not include encoder reload. */ + if (server_mode) { + flux_set_keep_text_encoder(ctx, 1); + } + /* Resolve auto-parameters now that we know the model type */ if (!steps_set || params.num_steps <= 0) { params.num_steps = flux_is_distilled(ctx) ? 4 : 50;