diff --git a/.gitignore b/.gitignore
index 364b379..6baf500 100644
--- a/.gitignore
+++ b/.gitignore
@@ -38,9 +38,6 @@ MANIFEST
*.so
.cache/
-# Development mode soft link
-tilert
-
!src/lib/
!include/lib/
diff --git a/Dockerfile b/Dockerfile
index fd1f7ec..b31f15b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,46 +1,128 @@
-FROM pytorch/manylinux2_28-builder:cuda12.9-main
+# TileRT release builder / runtime image
+#
+# Every dep version is pinned to the validated set. Don't bump anything
+# without re-running the full release pipeline (build wheel → fresh container
+# → pip install wheel → pytest on B200 GPUs).
+#
+# Especially: transformers MUST be 4.46.3. The 5.x branch is not backward
+# compatible with TileRT's tokenizer/model loading paths.
+#
+# Build:
+# docker build -t tileai/tilert:cu132-v0.1.4 .
+# Pull pre-built:
+# docker pull tileai/tilert:cu132-v0.1.4
+# Use:
+# docker run --rm --gpus all -v $PWD:/workspace -w /workspace \
+# tileai/tilert:cu132-v0.1.4 make wheel BUILD_TYPE=Release
+
+FROM pytorch/manylinux2_28-builder:cuda13.2-main
SHELL ["/bin/bash", "-c"]
-RUN yum update -y && \
- yum install -y epel-release yum-utils vim && \
- (yum config-manager --set-enabled powertools || \
- yum config-manager --set-enabled crb || true) && \
- yum --enablerepo=epel install -y glog glog-devel && \
- yum clean all && \
- rm -rf /var/cache/yum /var/tmp/* /tmp/*
-
-RUN conda init bash && \
- . /opt/conda/etc/profile.d/conda.sh && \
- conda create -y -n tilert python=3.12 && \
- conda activate tilert && \
- conda clean -afy && \
- rm -rf /opt/conda/pkgs/* /opt/conda/conda-meta/*.json.bak
+# ── System packages (glog: TileRT runtime dep; zstd: image transport) ────────
+RUN yum install -y --setopt=install_weak_deps=False \
+ epel-release yum-utils vim && \
+ (yum config-manager --set-enabled powertools 2>/dev/null || \
+ yum config-manager --set-enabled crb 2>/dev/null || true) && \
+ yum --enablerepo=epel install -y --setopt=install_weak_deps=False \
+ glog glog-devel zstd && \
+ rpm -e --nodeps cmake 2>/dev/null || true && \
+ yum clean all && rm -rf /var/cache/yum /var/tmp/* /tmp/*
-COPY requirements.txt requirements-dev.txt /tmp/
+# ── Conda env: python 3.12, named "tilert" ───────────────────────────────────
RUN . /opt/conda/etc/profile.d/conda.sh && \
- conda activate tilert && \
- pip install --no-cache-dir -r /tmp/requirements-dev.txt && \
- pip cache purge && \
- rm -rf /tmp/requirements*.txt /root/.cache/pip /root/.cache/* && \
+ conda create -y -n tilert python=3.12.9 && \
+ conda clean -afy && rm -rf /opt/conda/pkgs/*
+
+# ── Pinned lock set (resolved 2026-05-27 against torch 2.11.0+cu130 +
+# transformers 4.46.3 on python 3.12 / manylinux_2_28) ─────────────────────
+#
+# torch's METADATA transitively pins the nvidia-* cu13 runtime packages
+# (cublas==13.1.0.3, cudnn-cu13==9.19.0.56, nccl-cu13==2.28.9, etc.) — those
+# are NOT re-pinned here on purpose, so any patch bump in PyTorch's cu130
+# release line flows through.
+ARG PIP_INDEX_URL=https://download.pytorch.org/whl/cu130
+ARG PIP_EXTRA_INDEX_URL=https://pypi.org/simple
+RUN . /opt/conda/etc/profile.d/conda.sh && conda activate tilert && \
+ pip install --no-cache-dir \
+ --index-url "$PIP_INDEX_URL" \
+ --extra-index-url "$PIP_EXTRA_INDEX_URL" \
+ --upgrade pip==25.3 && \
+ pip install --no-cache-dir \
+ --index-url "$PIP_INDEX_URL" \
+ --extra-index-url "$PIP_EXTRA_INDEX_URL" \
+ "torch==2.11.0+cu130" \
+ "triton==3.6.0" \
+ "transformers==4.46.3" \
+ "tokenizers==0.20.3" \
+ "huggingface_hub==0.35.3" \
+ "hf_xet==1.1.10" \
+ "safetensors==0.6.2" \
+ "regex==2025.9.18" \
+ "requests==2.32.3" \
+ "charset_normalizer==3.3.2" \
+ "idna==3.7" \
+ "urllib3==2.3.0" \
+ "certifi==2026.2.25" \
+ "packaging==24.2" \
+ "tqdm==4.67.1" \
+ "pyyaml==6.0.2" \
+ "numpy==2.3.2" \
+ "einops==0.8.1" \
+ "filelock==3.29.0" \
+ "fsspec==2026.4.0" \
+ "jinja2==3.1.6" \
+ "MarkupSafe==3.0.3" \
+ "networkx==3.6.1" \
+ "sympy==1.14.0" \
+ "mpmath==1.3.0" \
+ "typing_extensions==4.15.0" \
+ "setuptools==81.0.0" \
+ "importlib_metadata==8.7.1" \
+ "zipp==3.23.0" \
+ "scikit-build-core==0.12.2" \
+ "setuptools-scm==9.2.2" \
+ "vcs-versioning==1.1.1" \
+ "pathspec==1.1.1" \
+ "ninja==1.13.0" \
+ "cmake==4.1.2" \
+ "pytest==8.4.1" \
+ "pytest-cov==7.1.0" \
+ "pluggy==1.6.0" \
+ "iniconfig==2.3.0" \
+ "pygments==2.20.0" \
+ "tomli==2.4.1" \
+ "coverage==7.10.7" \
+ "exceptiongroup==1.3.1" && \
+ python -c 'import torch, triton, transformers, tokenizers; assert torch.__version__ == "2.11.0+cu130", torch.__version__; assert torch.version.cuda.startswith("13"), torch.version.cuda; assert triton.__version__ == "3.6.0", triton.__version__; assert transformers.__version__ == "4.46.3", transformers.__version__; assert tokenizers.__version__ == "0.20.3", tokenizers.__version__; print("torch", torch.__version__, "cuda", torch.version.cuda, "| triton", triton.__version__, "| transformers", transformers.__version__, "| tokenizers", tokenizers.__version__, "OK")' && \
+ pip cache purge && rm -rf /root/.cache/pip /root/.cache/* && \
conda clean -afy && \
find /opt/conda -type f -name "*.pyc" -delete && \
find /opt/conda -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
-RUN echo "alias cls='clear'" >> ~/.bashrc && \
- echo "alias ll='ls -l'" >> ~/.bashrc && \
- echo "alias la='ls -a'" >> ~/.bashrc && \
- echo "alias vi='vim'" >> ~/.bashrc && \
- echo "alias grep='grep --color=auto'" >> ~/.bashrc && \
- echo "export PATH=\"/opt/conda/bin:\$PATH\"" >> ~/.bashrc && \
- echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
- echo "conda activate tilert" >> ~/.bashrc && \
- echo '#!/bin/bash' > /usr/local/bin/entrypoint.sh && \
- echo 'export PATH="/opt/conda/bin:$PATH"' >> /usr/local/bin/entrypoint.sh && \
- echo '. /opt/conda/etc/profile.d/conda.sh' >> /usr/local/bin/entrypoint.sh && \
- echo 'conda activate tilert' >> /usr/local/bin/entrypoint.sh && \
- echo 'exec "$@"' >> /usr/local/bin/entrypoint.sh && \
+# ── CUDA arch (Blackwell sm_100) + scikit-build pass-through ─────────────────
+ENV TORCH_CUDA_ARCH_LIST="10.0" \
+ CUDAARCHS="100" \
+ CMAKE_ARGS="-DUSER_CUDA_ARCH_LIST=10.0" \
+ SKBUILD_CMAKE_DEFINE="USER_CUDA_ARCH_LIST=10.0" \
+ CMAKE_BUILD_PARALLEL_LEVEL=16 \
+ PATH="/opt/conda/envs/tilert/bin:/opt/conda/bin:${PATH}"
+
+# ── Shell activation + entrypoint ─────────────────────────────────────────────
+RUN { echo 'export PATH=/opt/conda/envs/tilert/bin:/opt/conda/bin:$PATH'; \
+ echo '. /opt/conda/etc/profile.d/conda.sh'; \
+ echo 'conda activate tilert 2>/dev/null || true'; \
+ } >> /etc/bashrc && \
+ printf '%s\n' \
+ '#!/bin/bash' \
+ 'set -e' \
+ '. /opt/conda/etc/profile.d/conda.sh' \
+ 'conda activate tilert' \
+ 'exec "$@"' \
+ > /usr/local/bin/entrypoint.sh && \
chmod +x /usr/local/bin/entrypoint.sh
+WORKDIR /workspace
+
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["/bin/bash"]
diff --git a/README.md b/README.md
index fe5d3df..5cb3d7c 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,9 @@ ______________________________________________________________________
## 📰 News
-- 🏭 **2026-05-22 · [TileRT in Production](https://www.tilert.ai/blog/speed-as-the-next-scaling-law-zh.html)**. [GLM-5.1-highspeed](https://docs.bigmodel.cn/cn/guide/models/text/glm-5.1-highspeed) is now live on Z.ai, powered by TileRT — from experimental prototype to real production. TileRT-v0.1.4 is coming soon.
+- 🚀 **2026-06-01 · [v0.1.4](https://github.com/tile-ai/TileRT/releases/tag/v0.1.4) Released**. A major performance upgrade for both DeepSeek-V3.2 and GLM-5, with model quality unchanged. See the benchmark charts for details.
+
+- 🏭 **2026-05-22 · [TileRT in Production](https://www.tilert.ai/blog/speed-as-the-next-scaling-law-zh.html)**. [GLM-5.1-highspeed](https://docs.bigmodel.cn/cn/guide/models/text/glm-5.1-highspeed) is now live on Z.ai, powered by TileRT — from experimental prototype to real production.
- :fire: **2026-02-14 · [Try the Online Demo](https://www.tilert.ai/)**. Our online demo is now live! Experience ultra-low-latency inference with **GLM-5** and **DeepSeek-V3.2**. [Try it now !](https://www.tilert.ai)
@@ -49,68 +51,72 @@ To achieve this, TileRT introduces a **tile-level runtime engine**. Leveraging a
The project is actively evolving, and the underlying compiler techniques will be gradually shared with the community as they are integrated into **TileLang** and **TileScale**.
+
+
+
+ GLM-5.1-FP8 token generation speed on 8× NVIDIA B200 with TileRT v0.1.4. Output length 1K, input length 1K–192K. Bars compare TileRT without MTP, with MTP at average acceptance length 3.2, and the peak under best-case MTP acceptance.
+
+
______________________________________________________________________
## Installation
-- [Prerequisites](#prerequisites)
-- [Python Package Installation](#python-package-installation)
-
-### Prerequisites
-
-Before installing TileRT, ensure your environment meets the following requirements:
-
-**Hardware Requirements**
-
-- 8× NVIDIA B200 GPUs
-
-**Operating System**
-
-- Linux x86_64 (Ubuntu 20.04 or later recommended)
-
-**Python Version**
-
-- Python 3.11 – 3.12
- *(The wheel package is built and tested against these versions.)*
-
-**PyTorch Build**
+> \[!IMPORTANT\]
+> TileRT v0.1.4 is distributed as a **pre-built binary wheel**. The wheel is linked against the exact ABI of the versions listed below. Other combinations of Python, CUDA, or PyTorch versions are **untested and not guaranteed to work** — please reproduce this environment for a supported setup.
-- PyTorch wheels compiled for CUDA 12.8 or 12.9
- *(Must match the CUDA driver/runtime version required for B200 GPUs.)*
+### Build environment of the v0.1.4 wheel
-### Python Package Installation
+The official `tilert==0.1.4` wheel on PyPI was compiled against the following stack. Treat these as **hard requirements**, not lower bounds.
-> \[!IMPORTANT\]
-> **Disclaimer**: TileRT is an experimental project. The current pre-built package supports the 8-GPU B200 setup. For the most reliable experience, we strongly recommend installing the package within the provided Docker image.
+| Component | Pinned version |
+| ---------------- | --------------------------------------------------- |
+| GPU | 8× NVIDIA **B200** |
+| NVIDIA driver | Supports **CUDA 13.2** runtime |
+| Operating System | Linux **x86_64**, glibc **≥ 2.28** (manylinux_2_28) |
+| Python | **3.12** |
+| PyTorch | **`torch==2.11.0+cu130`** |
+| `transformers` | **`4.46.3`** |
+| `tokenizers` | **`0.20.3`** |
-The recommended installation method is using the pre-configured Docker image, which includes all necessary dependencies.
+### Recommended: pre-built Docker image
-**Step 1: Pull the Docker image**
+The pinned build environment above is preinstalled in our official image
+— this is the **recommended way to run v0.1.4** and avoids any version
+drift on the host. The image is mirrored to two registries; pull from
+whichever is reachable:
```bash
-docker pull tileai/tilert:v0.1.0
+# GitHub Container Registry
+docker pull ghcr.io/tile-ai/tilert:cu132-latest
+
+# Docker Hub
+docker pull tileai/tilert:cu132-latest
```
-**Step 2: Launch a Docker container**
+Launch a container with all 8 B200 GPUs attached, then install the
+wheel inside:
```bash
-IMAGE_NAME="tileai/tilert:v0.1.0"
-WORKSPACE_PATH="/path/to/your/workspace" # Replace with your actual workspace path
+docker run --rm -it --gpus all --ipc=host \
+ -v "$PWD":/workspace -w /workspace \
+ ghcr.io/tile-ai/tilert:cu132-latest
-docker run --gpus all -it \
- -v $WORKSPACE_PATH:/workspace/ \
- $IMAGE_NAME
-```
+# Inside the container — install from PyPI:
+pip install tilert==0.1.4
-**Step 3: Install the TileRT package**
+# Or pin the exact wheel from the GitHub Release page directly
+# (same artifact, useful when PyPI is unreachable):
+pip install https://github.com/tile-ai/TileRT/releases/download/v0.1.4/tilert-0.1.4-cp312-cp312-manylinux_2_28_x86_64.whl
+```
-Once inside the container, install TileRT using pip:
+Verify the install:
```bash
-pip install tilert
+python -c "import tilert, torch; print('tilert', tilert.__version__, '/ torch', torch.__version__, '/ cuda', torch.version.cuda)"
+# Expected: tilert 0.1.4 / torch 2.11.0+cu130 / cuda 13.0
```
-You're now ready to use TileRT! Proceed to the [Getting Started](#getting-started) section to download model weights and run your first inference.
+Proceed to [Getting Started](#getting-started) to download and convert model weights.
## Getting Started
@@ -118,11 +124,15 @@ You're now ready to use TileRT! Proceed to the [Getting Started](#getting-starte
Starting from release v0.1.3, TileRT no longer requires downloading pre-converted weights from Hugging Face. Instead, you can download the official model weights directly from the model's source (e.g., Hugging Face), and then convert them using the weight converter script included with the latest TileRT release.
-### Step 2: Convert Weights Using `weight_converter.py`
+### Step 2: Shard Weights with `weight_converter`
-After downloading the official model weights, you can use the following command to convert them into a format compatible with TileRT:
+The converter ships inside the `tilert` wheel. It rewrites the official HF
+checkpoint into TileRT's per-device layout — 8 shards, one per B200, with
+keys suffixed `*_dev_{0..7}` and a fresh `model.safetensors.index.json`.
+The runtime loads these shards directly; the original checkpoint is no
+longer needed after conversion.
-For **DeepSeek-V3.2**, run:
+For **DeepSeek-V3.2**:
```bash
python -m tilert.models.preprocess.weight_converter \
@@ -131,9 +141,7 @@ python -m tilert.models.preprocess.weight_converter \
--save_dir "/path/to/DeepSeek-V3.2-TileRT"
```
-Replace `/path/to/DeepSeek-V3.2` with the directory where you've downloaded the model weights, and `/path/to/DeepSeek-V3.2-TileRT` with the directory where you'd like the converted weights to be saved.
-
-Similarly, for **GLM-5**, run:
+For **GLM-5**:
```bash
python -m tilert.models.preprocess.weight_converter \
@@ -142,40 +150,52 @@ python -m tilert.models.preprocess.weight_converter \
--save_dir "/path/to/GLM-5-FP8-TileRT"
```
-Replace `/path/to/GLM-5-FP8` with the directory containing the downloaded GLM-5 model weights, and `/path/to/GLM-5-FP8-TileRT` with the desired location for saving the converted weights.
+`--model_dir` is the directory of the downloaded HF checkpoint;
+`--save_dir` is where the sharded TileRT-format weights will land.
-### Step 3: Set the Converted Weights Directory
+### Step 3: Register the Sharded Weights Path
-Once the weights are converted, set the environment variable to point TileRT to the directory containing the converted weights:
+Either pass `--model-weights-dir ` on every `tilert.generate`
+invocation, or register the path once in `~/.tilert/config.toml` so the
+CLI picks it up automatically:
-```bash
-export MODEL_WEIGHTS_DIR= ... # converted weights
+```toml
+[weights]
+deepseek_v3_2 = "/path/to/DeepSeek-V3.2-TileRT"
+glm5 = "/path/to/GLM-5-FP8-TileRT"
```
-Now you're ready to use TileRT with the converted weights!
-
### Running the Generation Example
-After downloading the model weights, you can run the generation example within the Docker environment as follows:
+The simplest entry point is the bundled CLI. Pick `--model deepseek_v3_2`
+or `--model glm5`; weights resolve from `~/.tilert/config.toml` or from
+an explicit `--model-weights-dir`:
```bash
-MODEL_WEIGHTS_DIR="/path/to/tilert_weights"
-
-docker run --gpus all -it \
- -v $WORKSPACE_PATH:/workspace/ \
- -v $MODEL_WEIGHTS_DIR:$MODEL_WEIGHTS_MOUNT \
- tilert:v0.1.0
+python -m tilert.generate --model deepseek_v3_2 --max-new-tokens 1000
```
-Once inside the container, run the following Python script to perform text generation:
+> \[!NOTE\]
+> v0.1.4 ships **two independent backend libraries** (`libtilert_dsv32.so`
+> and `libtilert_glm5.so`) and loads exactly one per Python process via
+> `tilert.load_backend(model_type)`. Run DeepSeek-V3.2 and GLM-5 in
+> separate processes — they cannot coexist in a single interpreter.
+
+To drive generation programmatically, load the backend first, then build
+the matching generator:
```python
-from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
+import tilert
+from tilert.models.deepseek_v3_2.generator import DSAv32Generator
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
-generator: ShowHandsGenerator = ShowHandsGenerator(
+tilert.load_backend("deepseek_v3_2")
+
+generator = DSAv32Generator(
+ model_args=ModelArgs(),
max_new_tokens=1000,
- model_weights_dir=MODEL_WEIGHTS_DIR,
- with_mtp=False, # Disable MTP
+ model_weights_dir="/path/to/DeepSeek-V3.2-TileRT",
+ with_mtp=False,
)
generator.from_pretrained()
@@ -193,6 +213,10 @@ print("Completion:")
completion = generator.generate(prompt)
```
+(For **GLM-5**, swap in `tilert.load_backend("glm5")` and
+`from tilert.models.glm_5.generator import GLM5Generator` with
+`ModelArgsGLM5`.)
+
For example, TileRT may generate:
@@ -210,17 +234,26 @@ This example demonstrates basic single-step autoregressive generation using the
### Running the Generation Example with Multi-Token Prediction (MTP)
-TileRT also supports Multi-Token Prediction (MTP), which allows the model to generate multiple tokens per forward pass and reduces sequential decoding depth.
+TileRT also supports Multi-Token Prediction (MTP), which allows the model to generate multiple tokens per forward pass and reduces sequential decoding depth. Enable it from the CLI with `--with-mtp`:
+
+```bash
+python -m tilert.generate --model deepseek_v3_2 --with-mtp --max-new-tokens 1000
+```
-To better illustrate MTP behavior, we use a longer prompt that encourages extended generation:
+Or programmatically, pass `with_mtp=True` to the generator:
```python
-from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
+import tilert
+from tilert.models.deepseek_v3_2.generator import DSAv32Generator
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+
+tilert.load_backend("deepseek_v3_2")
-generator: ShowHandsGenerator = ShowHandsGenerator(
+generator = DSAv32Generator(
+ model_args=ModelArgs(),
max_new_tokens=1000,
- model_weights_dir=MODEL_WEIGHTS_DIR,
- with_mtp=True, # Enable MTP
+ model_weights_dir="/path/to/DeepSeek-V3.2-TileRT",
+ with_mtp=True,
)
generator.from_pretrained()
prompt = "Tell me 10 jokes, keep them all under 100 words."
@@ -269,7 +302,7 @@ Of course! Here are 10 short jokes for you.
This example highlights how MTP enables TileRT to efficiently generate longer outputs by accepting multiple tokens per decoding step, while preserving the same Python API interface.
-For more details, please refer to the [generation script](https://github.com/tile-ai/TileRT/blob/main/python/generate.py).
+For the full list of CLI flags (sampling, batching, benchmark modes, …), run `python -m tilert.generate --help`.
## Status & Future Work
diff --git a/assets/glm5-mtp.png b/assets/glm5-mtp.png
deleted file mode 100644
index d9ebb32..0000000
Binary files a/assets/glm5-mtp.png and /dev/null differ
diff --git a/assets/glm5-without-mtp.png b/assets/glm5-without-mtp.png
deleted file mode 100644
index 28ecf08..0000000
Binary files a/assets/glm5-without-mtp.png and /dev/null differ
diff --git a/assets/glm5_tilert_mtp.png b/assets/glm5_tilert_mtp.png
new file mode 100644
index 0000000..c4db83f
Binary files /dev/null and b/assets/glm5_tilert_mtp.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 4012628..13aa491 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,16 @@ classifiers = [
]
dependencies = [
- "torch",
+ # Pinned to the exact ABI the v0.1.4 wheel was built against. ``torch`` must
+ # come from PyTorch's cu130 index (``--index-url
+ # https://download.pytorch.org/whl/cu130``); installing from PyPI yields a
+ # CUDA build that does not match the cu130-linked tilert binary.
+ "torch==2.11.0",
+ "transformers==4.46.3",
+ "tokenizers==0.20.3",
+ "numpy",
+ "scipy",
+ "einops",
]
[project.optional-dependencies]
@@ -55,18 +64,10 @@ dev = [
Homepage = "https://github.com/tile-ai/TileRT"
Issues = "https://github.com/tile-ai/TileRT/issues"
-[build-system]
-requires = [
- "cmake>=3.25.0", # required by c++ 20
- "packaging",
- "setuptools>=64.0.0",
- "setuptools-scm>=8.0",
- "wheel",
- "pytest-runner",
- "pytest",
- "torch"
-]
-build-backend = "setuptools.build_meta"
+# Note: this repository ships the public sources that match the v0.1.4 wheel.
+# The wheel itself is built in the development repo (TileRT-dev/TileRT) with
+# scikit-build-core; no [build-system] block is declared here on purpose so
+# nobody accidentally runs ``pip wheel .`` against this presentation copy.
[tool.black]
line-length = 100
diff --git a/python/__init__.py b/python/__init__.py
deleted file mode 100644
index dbda493..0000000
--- a/python/__init__.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""TileRT Python package."""
-
-import ctypes
-import logging
-from pathlib import Path
-from typing import Any
-
-import torch
-
-if not hasattr(torch, "ops"):
- raise RuntimeError("PyTorch is required but torch.ops is not available")
-
-from .__version__ import __version__
-
-
-def init_logging() -> logging.Logger:
- """Initialize logging configuration."""
- logging.basicConfig(
- level=logging.DEBUG,
- format="%(filename)s:%(lineno)d [%(levelname)s]: %(message)s",
- )
- return logging.getLogger(__name__)
-
-
-logger = init_logging()
-
-
-def _load_library(filename: str) -> Any:
- """Load the C++ library.
-
- Args:
- filename: Name of the library file.
-
- Returns:
- Any: The loaded library.
-
- Raises:
- RuntimeError: If the library cannot be loaded.
- """
- lib_path = Path(__file__).parent / filename
-
- try:
- torch.ops.load_library(str(lib_path))
- return lib_path
- except Exception as e:
- raise RuntimeError(f"Failed to load library from {lib_path}") from e
-
-
-_load_library("libtilert.so")
-
-
-from . import models # noqa: E402
-from .models import deepseek_v3_2 # noqa: E402
-from .tilert_init import tilert_init # noqa: E402
-
-__all__ = [
- "logger",
- "tilert_init",
- "models",
- "deepseek_v3_2",
- "__version__",
-]
diff --git a/python/__version__.py b/python/__version__.py
deleted file mode 100644
index 08768ce..0000000
--- a/python/__version__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from __future__ import annotations
-
-from importlib.metadata import PackageNotFoundError
-from importlib.metadata import version as pkg_version
-from pathlib import Path
-
-try:
- __version__ = pkg_version("tilert")
-except PackageNotFoundError:
- try:
- from setuptools_scm import get_version
-
- __version__ = get_version(
- root=str(Path(__file__).resolve().parents[1]),
- relative_to=__file__,
- )
- except Exception:
- __version__ = "0.0.0"
-
-__all__ = ["__version__"]
diff --git a/python/benchmark/long_prompt.py b/python/benchmark/long_prompt.py
deleted file mode 100644
index f6d4d0e..0000000
--- a/python/benchmark/long_prompt.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Long-prompt benchmark: single generation, measures long-form throughput."""
-
-from typing import cast
-
-import numpy as np
-from benchmark import BenchMode, BenchStats, CellStats, Generator, apply_mode
-
-PROMPT = "Hi, can you tell me a very long story, with roughly 3000 words?"
-
-
-def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
- """Run the long-prompt benchmark for each mode.
-
- Returns stats with column: Long.
- """
- stats: BenchStats = {}
-
- for mode in modes:
- apply_mode(generator, mode)
- print(f"\n--- Long-prompt benchmark ({mode.label}) ---")
- print(f"Prompt: {PROMPT}")
- print("Completion:")
-
- _, time_list, accepted_counts = cast(
- tuple[str, list[float], list[int]],
- generator.generate(PROMPT, True, with_mtp=mode.with_mtp),
- )
-
- mode_stats: dict[str, CellStats] = {}
-
- if mode.with_mtp and accepted_counts:
- total_tokens = sum(accepted_counts)
- total_time = sum(time_list)
- speed = total_tokens / total_time if total_time > 0 else 0
- avg_ms = total_time / len(time_list) * 1000
- avg_a = total_tokens / len(accepted_counts)
- acc_rate = f"{avg_a:.2f}/{min(accepted_counts)}/{max(accepted_counts)}"
- mode_stats["Long"] = CellStats(tok_s=speed, ms=avg_ms, acc_rate=acc_rate)
- elif time_list:
- mean_time = float(np.mean(time_list))
- speed = 1 / mean_time
- mode_stats["Long"] = CellStats(tok_s=speed, ms=mean_time * 1000)
-
- stats[mode.label] = mode_stats
-
- return stats
diff --git a/python/generate.py b/python/generate.py
deleted file mode 100644
index 5724e8e..0000000
--- a/python/generate.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""Text generation script for TileRT."""
-
-from argparse import ArgumentParser
-
-from benchmark import BenchMode
-from benchmark import coding_prompt as coding_bench
-from benchmark import long_prompt as long_bench
-from benchmark import merge_stats, print_summary_table
-from benchmark import short_prompt as short_bench
-
-from tilert.models.deepseek_v3_2.generator import DSAv32Generator
-from tilert.models.deepseek_v3_2.model_args import ModelArgs as DSAv32ModelArgs
-from tilert.models.glm_5.generator import GLM5Generator
-from tilert.models.glm_5.model_args import ModelArgsGLM5
-
-
-def get_generator(
- model_type: str,
- max_new_tokens: int,
- temperature: float,
- model_weights_dir: str,
- with_mtp: bool,
- top_p: float = 0.9,
- top_k: int = 256,
- enable_thinking: bool = False,
- sampling_seed: int = 42,
-) -> DSAv32Generator | GLM5Generator:
- """Get the appropriate generator based on model type."""
- assert model_type in ["deepseek_v3_2", "glm5"]
- if model_type == "deepseek_v3_2":
- model_args = DSAv32ModelArgs()
- return DSAv32Generator(
- model_args=model_args,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- model_weights_dir=model_weights_dir,
- with_mtp=with_mtp,
- top_p=top_p,
- top_k=top_k,
- use_topp=top_p < 1.0,
- sampling_seed=sampling_seed,
- )
- model_args = ModelArgsGLM5()
- return GLM5Generator(
- model_args=model_args,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- model_weights_dir=model_weights_dir,
- with_mtp=with_mtp,
- top_p=top_p,
- top_k=top_k,
- use_topp=top_p < 1.0,
- enable_thinking=enable_thinking,
- sampling_seed=sampling_seed,
- )
-
-
-def parse_args(): # type: ignore
- parser = ArgumentParser(description="Command-line interface for text generation.")
- parser.add_argument(
- "--model-weights-dir",
- type=str,
- required=True,
- help="Path to model weights directory",
- )
- parser.add_argument(
- "--model",
- type=str,
- default="deepseek_v3_2",
- choices=["deepseek_v3_2", "glm5"],
- help="Model type to use (default: deepseek_v3_2)",
- )
- parser.add_argument("--max-new-tokens", type=int, default=4000, help="Max tokens to generate")
- parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
- parser.add_argument(
- "--top-p",
- type=float,
- default=1.0,
- help="Top-p (nucleus) sampling threshold. Use < 1.0 to enable top-p sampling (e.g. 0.9)",
- )
- parser.add_argument("--top-k", type=int, default=256, help="Top-k sampling threshold")
- parser.add_argument("--interactive", action="store_true")
- parser.add_argument(
- "--with-mtp",
- action="store_true",
- help="Enable MTP (Multi-Token Prediction) for speculative decoding",
- )
- parser.add_argument(
- "--use-random-weights",
- action="store_true",
- help="Use random weights instead of pretrained (for testing MTP without real weights)",
- )
- parser.add_argument(
- "--enable-thinking",
- action="store_true",
- help="Enable thinking mode in chat template",
- )
- parser.add_argument(
- "--sampling-seed",
- type=int,
- default=42,
- help="Sampling seed for top-p sampling (fixed per request, default: 42)",
- )
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- """
- usage:
- execute below command under tilert root directory:
-
- # DeepSeek V3.2 - Standard generation with pretrained weights:
- python python/generate.py --model-weights-dir "xxxx" 2>&1 | tee test.log
-
- # DeepSeek V3.2 - MTP generation with random weights (for testing):
- python python/generate.py --model-weights-dir "xxxx" --with-mtp \
- --use-random-weights 2>&1 | tee test.log
-
- # DeepSeek V3.2 - MTP generation with pretrained weights (when available):
- python python/generate.py --model-weights-dir "xxxx" --with-mtp 2>&1 | tee test.log
-
- # GLM5 - Standard generation with random weights (for testing):
- python python/generate.py --model glm5 --model-weights-dir "xxxx" \
- --use-random-weights 2>&1 | tee test.log
-
- # GLM5 - Standard generation with pretrained weights:
- python python/generate.py --model glm5 --model-weights-dir "xxxx" 2>&1 | tee test.log
-
- # GLM5 - MTP generation with random weights (for testing):
- python python/generate.py --model glm5 --model-weights-dir "xxxx" --with-mtp \
- --use-random-weights 2>&1 | tee test.log
-
- # GLM5 - MTP generation with pretrained weights:
- python python/generate.py --model glm5 --model-weights-dir "xxxx" --with-mtp \
- 2>&1 | tee test.log
- """
- args = parse_args()
-
- generator = get_generator(
- model_type=args.model,
- max_new_tokens=args.max_new_tokens,
- temperature=args.temperature,
- model_weights_dir=args.model_weights_dir,
- with_mtp=args.with_mtp if args.interactive else True,
- top_p=args.top_p,
- top_k=args.top_k,
- enable_thinking=args.enable_thinking,
- sampling_seed=args.sampling_seed,
- )
-
- print("Loading pretrained weights...")
- generator.from_pretrained()
-
- # simple memoryless interactive mode
- if args.interactive:
- print("Welcome to the TileRT interactive mode! Type '/exit' to exit.")
- while True:
- prompt = input(">>> ")
- if prompt == "/exit":
- break
- _ = generator.generate(prompt) # type: ignore[has-type]
- else:
-
- # 3 modes: top-k1 w/o MTP, top-k1 w/ MTP, top-p0.95 w/ MTP
- modes = [
- BenchMode(with_mtp=False, label="top-k1 w/o MTP"),
- BenchMode(with_mtp=True, label="top-k1 w/ MTP"),
- BenchMode(
- with_mtp=True,
- label="top-p0.95 w/ MTP",
- use_topp=True,
- top_p=0.95,
- top_k=args.top_k,
- temperature=args.temperature,
- ),
- ]
-
- # Run all benchmarks and collect stats
- all_bench_stats = [
- short_bench.run(generator, modes),
- coding_bench.run(generator, modes),
- long_bench.run(generator, modes),
- ]
-
- # Print unified summary table
- print_summary_table(
- merge_stats(all_bench_stats),
- model_name=args.model.upper(),
- )
-
- print("Cleaning up...")
- generator.cleanup()
diff --git a/python/models/common.py b/python/models/common.py
deleted file mode 100644
index b213793..0000000
--- a/python/models/common.py
+++ /dev/null
@@ -1,361 +0,0 @@
-from typing import cast
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = [
- "init_func",
- "linear",
- "Linear",
- "RMSNorm",
- "LayerNorm",
- "ColumnParallelLinear",
- "RowParallelLinear",
- "ParallelEmbedding",
-]
-
-from tilert.models.deepseek_config import (
- block_size,
- gemm_impl,
- get_rank,
- get_world_size,
- is_distributed,
-)
-from tilert.models.deepseek_v3_2.refs.kernel import act_quant, fp8_gemm, weight_dequant
-
-
-def _get_scale_tensor(tensor: torch.Tensor) -> torch.Tensor:
- """Return the dynamically attached ``scale`` tensor."""
- scale = getattr(tensor, "scale", None)
- if scale is None:
- raise AttributeError("Expected quantized tensor to carry a 'scale' attribute.")
- return cast(torch.Tensor, scale)
-
-
-def init_func(x_in: torch.Tensor) -> torch.Tensor:
- x_dtype = x_in.dtype
- x_fp32 = x_in.to(torch.float32)
- if x_fp32.dim() >= 2:
- initial_tensor = nn.init.kaiming_uniform_(x_fp32)
- else:
- initial_tensor = nn.init.uniform_(x_fp32)
- return initial_tensor.to(x_dtype)
-
-
-def linear(
- x_in: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor | None = None,
- scale_fmt: str | None = None,
-) -> torch.Tensor:
- """
- Applies a linear transformation to the incoming data: y = xA^T + b.
-
- This function supports specialized implementations based on quantization
- and tensor formats.
-
- Args:
- x_in (torch.Tensor): The input tensor.
- weight (torch.Tensor): The weight tensor. It may be quantized and
- requires dequantization for certain cases.
- bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
-
- Returns:
- torch.Tensor: The result of the linear transformation, which may involve
- quantization-aware computations depending on the input parameters.
-
- Notes:
- - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version is used
- for computation.
- - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
- - For other cases, the function applies quantization to `x` and uses `fp8_gemm`
- for computation.
- """
- if weight.element_size() > 1:
- return F.linear(x_in, weight, bias)
- if gemm_impl == "bf16":
- weight = weight_dequant(weight, _get_scale_tensor(weight))
- return F.linear(x_in, weight, bias)
-
- x_quant: torch.Tensor
- scale: torch.Tensor
- x_quant, scale = act_quant(x_in, block_size, scale_fmt)
- y_out: torch.Tensor = fp8_gemm(x_quant, scale, weight, _get_scale_tensor(weight))
- if bias is not None:
- y_out += bias
- return y_out
-
-
-class Linear(nn.Module):
- """
- Custom linear layer with support for quantized weights and optional bias.
-
- Args:
- in_features (int): Number of input features.
- out_features (int): Number of output features.
- bias (bool): Whether to include a bias term. Defaults to False.
- dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
- """
-
- dtype = torch.bfloat16
- scale_fmt: str | None = None
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- dtype: torch.dtype | None = None,
- weight: torch.Tensor | None = None,
- ):
- super().__init__()
- self.in_features = in_features
- self.out_features = out_features
-
- if weight is not None:
- self.weight = torch.nn.Parameter(weight)
- else:
- self.weight = nn.Parameter(
- init_func(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
- )
-
- if self.weight.element_size() == 1:
- scale_out_features = (out_features + block_size - 1) // block_size
- scale_in_features = (in_features + block_size - 1) // block_size
- scale_param = nn.Parameter(
- init_func(
- torch.empty(
- scale_out_features,
- scale_in_features,
- dtype=torch.float32,
- )
- )
- )
- self.scale = scale_param
- self.weight.scale = scale_param # type: ignore[attr-defined]
- else:
- self.register_parameter("scale", None)
-
- if bias:
- self.bias = nn.Parameter(init_func(torch.empty(out_features)))
- else:
- self.register_parameter("bias", None)
-
- def forward(self, x_in: torch.Tensor) -> torch.Tensor:
- """
- Forward pass for the custom linear layer.
-
- Args:
- x (torch.Tensor): Input tensor.
-
- Returns:
- torch.Tensor: Transformed tensor after linear computation.
- """
- return linear(x_in, self.weight, self.bias, self.scale_fmt)
-
-
-class RMSNorm(nn.Module):
- """
- Root Mean Square Layer Normalization (RMSNorm).
-
- Args:
- dim (int): Dimension of the input tensor.
- eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
- """
-
- def __init__(self, dim: int, eps: float = 1e-6, weight: torch.Tensor | None = None):
- super().__init__()
- self.dim = dim
- self.eps = eps
-
- if weight is None:
- self.weight = nn.Parameter(init_func(torch.empty(dim, dtype=torch.float32)))
- else:
- self.weight = torch.nn.Parameter(weight)
-
- def forward(
- self, x: torch.Tensor, residual: torch.Tensor | None = None
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- """
- Forward pass for RMSNorm.
-
- Args:
- x (torch.Tensor): Input tensor.
-
- Returns:
- torch.Tensor: Normalized tensor with the same shape as input.
- """
- dtype = torch.bfloat16 # x.dtype
- if residual is None:
- x = x.float()
- var_s = x.pow(2).mean(-1, keepdim=True)
- x = x * torch.rsqrt(var_s + self.eps)
- return (self.weight * x).to(dtype)
-
- x = residual = x.float() + residual.float()
- var_s = x.pow(2).mean(-1, keepdim=True)
- x = x * torch.rsqrt(var_s + self.eps)
- return (self.weight * x).to(dtype), residual.to(dtype)
-
-
-class LayerNorm(nn.Module):
- """Layer Normalization."""
-
- def __init__(self, dim: int, eps: float = 1e-6):
- super().__init__()
- self.dim = dim
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
- self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
-
-
-class ColumnParallelLinear(Linear):
- """
- Column parallel linear layer.
-
- Linear layer with column parallelism, splitting output features across
- distributed processes.
-
- Args:
- in_features (int): Number of input features.
- out_features (int): Total number of output features.
- bias (bool): Whether to include a bias term. Defaults to False.
- dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
- """
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- dtype: torch.dtype | None = None,
- ):
- world_size = get_world_size()
- assert (
- out_features % world_size == 0
- ), f"Output features must be divisible by world size {world_size}"
- self.part_out_features = out_features // world_size
- super().__init__(in_features, self.part_out_features, bias, dtype)
-
- def forward(self, x_in: torch.Tensor) -> torch.Tensor:
- """
- Forward pass for column parallel linear layer.
-
- Args:
- x (torch.Tensor): Input tensor.
-
- Returns:
- torch.Tensor: Transformed tensor with column-parallel computation.
- """
- return linear(x_in, self.weight, self.bias)
-
-
-class RowParallelLinear(Linear):
- """
- Linear layer with row parallelism, splitting input features across distributed processes.
-
- Args:
- in_features (int): Total number of input features.
- out_features (int): Number of output features.
- bias (bool): Whether to include a bias term. Defaults to False.
- dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
- """
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- reduce_output: bool = True,
- dtype: torch.dtype | None = None,
- ):
-
- self.world_size = get_world_size()
-
- if in_features % self.world_size != 0:
- raise ValueError(
- f"Input features must be divisible by world size (world_size={self.world_size})"
- )
-
- self.part_in_features = in_features // self.world_size
- self.reduce_output = reduce_output
-
- super().__init__(self.part_in_features, out_features, bias, dtype)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """
- Forward pass for row parallel linear layer.
-
- Args:
- x (torch.Tensor): Input tensor.
-
- Returns:
- torch.Tensor: Transformed tensor with row-parallel computation.
- """
- y = linear(x, self.weight, None, self.scale_fmt)
- if self.reduce_output and is_distributed() > 1:
- y = y.float()
- dist.all_reduce(y)
- if self.bias is not None:
- y += self.bias
- return y.type_as(x)
-
-
-class ParallelEmbedding(nn.Module):
- """
- Parallel embedding layer.
-
- Embedding layer with parallelism support across distributed processes.
-
- Args:
- vocab_size (int): Vocabulary size.
- dim (int): Embedding dimension.
- """
-
- def __init__(self, vocab_size: int, dim: int):
- super().__init__()
- self.vocab_size = vocab_size
- self.dim = dim
-
- self.world_size = get_world_size()
- self.rank = get_rank()
-
- assert (
- vocab_size % self.world_size == 0
- ), f"Vocabulary size must be divisible by world size {self.world_size}"
-
- self.part_vocab_size = vocab_size // self.world_size
- self.vocab_start_idx = self.rank * self.part_vocab_size
- self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
-
- self.weight = nn.Parameter(init_func(torch.empty(self.part_vocab_size, self.dim)))
-
- def forward(self, x_in: torch.Tensor) -> torch.Tensor:
- """
- Forward pass for parallel embedding layer.
-
- Args:
- x (torch.Tensor): Input tensor containing token indices.
-
- Returns:
- torch.Tensor: Embedded representations.
-
- Raises:
- ValueError: If `world_size` is not defined.
- """
- if self.world_size > 1:
- mask = (x_in < self.vocab_start_idx) | (x_in >= self.vocab_end_idx)
- x_in = x_in - self.vocab_start_idx
- x_in[mask] = 0
-
- y_out = F.embedding(x_in, self.weight)
-
- if is_distributed():
- y_out[mask] = 0
- dist.all_reduce(y_out)
- return y_out
diff --git a/python/models/deepseek_config.py b/python/models/deepseek_config.py
deleted file mode 100644
index 6c7f5da..0000000
--- a/python/models/deepseek_config.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""Global configuration for DeepSeek models."""
-
-import os
-from typing import Literal
-
-import torch
-import torch.distributed as dist
-
-__all__ = [
- "get_world_size",
- "get_rank",
- "block_size",
- "gemm_impl",
- "attn_impl",
-]
-
-
-def is_distributed() -> bool:
- return bool(dist.is_initialized())
-
-
-def get_world_size() -> int:
- # NOTE: default world size is 8, since tilert kernels implemented for 8 GPUs.
- # DO NOT modify this value unless you know how much it affects the tilert kernels.
- return dist.get_world_size() if dist.is_initialized() else 8
-
-
-def get_rank() -> int:
- return dist.get_rank() if dist.is_initialized() else 0
-
-
-def init_distributed_training() -> tuple[int, int, bool]:
- """Initialize distributed training."""
- if "LOCAL_RANK" in os.environ:
- local_rank = int(os.environ["LOCAL_RANK"])
- world_rank = int(os.environ["RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
- is_distributed = True
- else:
- local_rank = 0
- world_rank = 0
- world_size = 1
- is_distributed = False
-
- torch.cuda.set_device(local_rank)
- torch.set_default_device(f"cuda:{local_rank}")
- torch.set_default_dtype(torch.bfloat16)
-
- if world_size > 1:
- dist.init_process_group(
- backend="nccl",
- world_size=world_size,
- rank=world_rank,
- init_method="env://",
- device_id=local_rank,
- )
-
- rank = get_rank()
- world_size = get_world_size()
-
- return rank, world_size, is_distributed
-
-
-block_size = 128
-gemm_impl: Literal["bf16", "fp8"] = "bf16"
-attn_impl: Literal["naive", "absorb"] = "absorb"
diff --git a/python/models/deepseek_v3_2/modules/mla.py b/python/models/deepseek_v3_2/modules/mla.py
deleted file mode 100644
index 6f1d138..0000000
--- a/python/models/deepseek_v3_2/modules/mla.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import torch
-
-from tilert.models.base import SerializableTileRTModule
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.models.deepseek_v3_2.ops.layernorm_rope_rotate import LayerNormRoPERotate
-from tilert.models.deepseek_v3_2.ops.projo_wkvb import ProjoWKVb
-from tilert.models.deepseek_v3_2.ops.projq_wqb import ProjqWqb
-from tilert.models.deepseek_v3_2.ops.projx_wis import ProjxWis
-from tilert.models.deepseek_v3_2.ops.rmsnorm_kv import KVRMSNorm
-from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqib import (
- RmsnormProjqWqib,
- RmsnormProjqWqibAlgorithm,
-)
-from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkvia import (
- RMSNormProjxWqkvia,
- RMSNormProjxWqkviaAlgorithm,
-)
-from tilert.models.deepseek_v3_2.ops.unproj_o_allreduce import (
- UnProjOAllReduce,
- UnProjOAllReduceAlgorithm,
-)
-
-
-class Mla(SerializableTileRTModule):
- """Implement the MLA operations."""
-
- def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
- super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
-
- self.rmsnorm_projx_wqkvia = RMSNormProjxWqkvia(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- if model_args.arch_name == "glm_5":
- self.rmsnorm_projx_wqkvia.algorithm = RMSNormProjxWqkviaAlgorithm.DECOUPLED
- else:
- self.rmsnorm_projx_wqkvia.algorithm = RMSNormProjxWqkviaAlgorithm.GENERAL
- self.register_op(self.rmsnorm_projx_wqkvia)
-
- self.layernorm_rope_rotate = LayerNormRoPERotate(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- self.register_op(self.layernorm_rope_rotate)
-
- self.rmsnorm_projq_wqib = RmsnormProjqWqib(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- if model_args.arch_name == "glm_5":
- self.rmsnorm_projq_wqib.algorithm = RmsnormProjqWqibAlgorithm.FP16MMA
- else:
- self.rmsnorm_projq_wqib.algorithm = RmsnormProjqWqibAlgorithm.BF16
- self.register_op(self.rmsnorm_projq_wqib)
-
- self.projx_wis = ProjxWis(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- self.register_op(self.projx_wis)
-
- self.projq_wqb = ProjqWqb(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- self.register_op(self.projq_wqb)
-
- self.rmsnorm_kv = KVRMSNorm(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- self.register_op(self.rmsnorm_kv)
-
- self.projo_wkvb = ProjoWKVb(
- model_args=model_args, device_id=device_id, num_devices=num_devices
- )
- self.register_op(self.projo_wkvb)
-
- self.unproj_o_allreduce = UnProjOAllReduce(
- model_args=model_args,
- device_id=device_id,
- num_devices=num_devices,
- algorithm=UnProjOAllReduceAlgorithm.FP8MMA,
- )
-
- if model_args.arch_name == "glm_5":
- self.unproj_o_allreduce.algorithm = UnProjOAllReduceAlgorithm.FP16MMA
-
- self.register_op(self.unproj_o_allreduce)
-
- self.kv_cache: torch.Tensor | None = None
- self.pe_cache: torch.Tensor | None = None
- self.ki_cache: torch.Tensor | None = None
-
- def get_cache_vars(self) -> list[torch.Tensor]:
- cache_seq_len = self.model_args.max_seq_len + self.model_args.kv_cache_pad
- bs_args = (self.model_args.max_batch_size, cache_seq_len)
- if self.kv_cache is None:
- kv_dim = self.model_args.kv_lora_rank
- self.kv_cache = torch.zeros(
- *bs_args, kv_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
- )
- if self.pe_cache is None:
- pe_dim = self.model_args.qk_rope_head_dim
- self.pe_cache = torch.zeros(
- *bs_args, pe_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
- )
- if self.ki_cache is None:
- ki_dim = self.model_args.index_head_dim
- self.ki_cache = torch.zeros(
- *bs_args, ki_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
- )
- return [*super().get_cache_vars(), self.ki_cache, self.kv_cache, self.pe_cache]
diff --git a/python/models/deepseek_v3_2/ops/__init__.py b/python/models/deepseek_v3_2/ops/__init__.py
deleted file mode 100644
index e832905..0000000
--- a/python/models/deepseek_v3_2/ops/__init__.py
+++ /dev/null
@@ -1,109 +0,0 @@
-"""Core operations for deepseek v3.2."""
-
-from tilert.models.deepseek_v3_2.ops.down_allreduce import (
- DownAllReduce,
- down_allreduce,
- down_allreduce_glm5,
-)
-from tilert.models.deepseek_v3_2.ops.eh_proj_allreduce import EHProjAllReduce, eh_proj_allreduce
-from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import (
- ExpertDownAllReduce,
- expert_down_allreduce,
-)
-from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import (
- ExpertSelectUpGateSiLU,
- ExpertSelectUpGateSiLUAlgorithm,
-)
-from tilert.models.deepseek_v3_2.ops.expert_select import expert_select
-from tilert.models.deepseek_v3_2.ops.flash_sparse_mla import flash_sparse_mla
-from tilert.models.deepseek_v3_2.ops.layernorm_rope_rotate import layernorm_rope_rotate
-from tilert.models.deepseek_v3_2.ops.projo_wkvb import projo_wkvb
-from tilert.models.deepseek_v3_2.ops.projq_wqb import projq_wqb
-from tilert.models.deepseek_v3_2.ops.projx_wis import projx_wis
-from tilert.models.deepseek_v3_2.ops.qkv_rope import (
- QKVRoPE,
- QKVRoPERefWeightsAlias,
- QKVRoPETilertWeightsAlias,
- qkv_rope,
-)
-from tilert.models.deepseek_v3_2.ops.rmsnorm_expert_proj import RMSNormExpertProj
-from tilert.models.deepseek_v3_2.ops.rmsnorm_head_proj import RMSNormHeadProj
-from tilert.models.deepseek_v3_2.ops.rmsnorm_kv import rmsnorm_kv
-from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqib import (
- RmsnormProjqWqib,
- RmsnormProjqWqibAlgorithm,
- RmsnormProjqWqibWeightsConverter,
-)
-from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkvia import (
- RMSNormProjxWqkvia,
- RMSNormProjxWqkviaAlgorithm,
- projx_wqkvia,
- rmsnorm_projx_wqkvia,
-)
-from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant
-from tilert.models.deepseek_v3_2.ops.rmsnorm_up_gate_silu import (
- RMSNormUpGateSiLU,
- RMSNormUpGateSiLUAlgorithm,
-)
-from tilert.models.deepseek_v3_2.ops.rotate import (
- Rotate,
- RotateRefWeightsAlias,
- RotateTilertWeightsAlias,
- rotate,
- rotate_activation,
-)
-from tilert.models.deepseek_v3_2.ops.sparse_index import sparse_index, sparse_index_topk
-from tilert.models.deepseek_v3_2.ops.topk import TopK, topk_accurate, topk_approximate
-from tilert.models.deepseek_v3_2.ops.unproj_o_allreduce import (
- UnProjOAllReduce,
- UnProjOAllReduceAlgorithm,
- unproj_o_allreduce,
-)
-from tilert.models.deepseek_v3_2.ops.up_gate_silu import up_gate_silu
-
-__all__ = [
- "down_allreduce",
- "down_allreduce_glm5",
- "DownAllReduce",
- "expert_down_allreduce",
- "ExpertDownAllReduce",
- "expert_select",
- "up_gate_silu",
- "rmsnorm_projx_wqkvia",
- "projx_wqkvia",
- "rmsnorm_kv",
- "unproj_o_allreduce",
- "projo_wkvb",
- "projq_wqb",
- "rotate",
- "rotate_activation",
- "Rotate",
- "RotateRefWeightsAlias",
- "RotateTilertWeightsAlias",
- "layernorm_rope_rotate",
- "TopK",
- "topk_approximate",
- "topk_accurate",
- "sparse_index",
- "sparse_index_topk",
- "flash_sparse_mla",
- "projx_wis",
- "qkv_rope",
- "QKVRoPE",
- "QKVRoPERefWeightsAlias",
- "QKVRoPETilertWeightsAlias",
- "eh_proj_allreduce",
- "rmsnorm_quant",
- "RmsnormProjqWqib",
- "RmsnormProjqWqibAlgorithm",
- "RmsnormProjqWqibWeightsConverter",
- "RMSNormExpertProj",
- "RMSNormProjxWqkvia",
- "RMSNormProjxWqkviaAlgorithm",
- "RMSNormUpGateSiLU",
- "UnProjOAllReduce",
- "UnProjOAllReduceAlgorithm",
- "RMSNormHeadProj",
- "ExpertSelectUpGateSiLU",
- "ExpertSelectUpGateSiLUAlgorithm",
-]
diff --git a/python/models/deepseek_v3_2/ops/expert_select.py b/python/models/deepseek_v3_2/ops/expert_select.py
deleted file mode 100644
index 6a16d76..0000000
--- a/python/models/deepseek_v3_2/ops/expert_select.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""ExpertSelect operation module."""
-
-import torch
-
-__all__ = [
- "expert_select",
- "expert_select_one_stage",
-]
-
-
-def expert_select(
- scores_in: torch.Tensor,
- bias_in: torch.Tensor,
- expert_probs_out: torch.Tensor,
- expert_indices_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Expert Select operation.
-
- Original two-stage expert select operation used in DeepSeek V3.2.
- """
- torch.ops.tilert.expert_select_op(
- scores_in,
- bias_in,
- expert_probs_out,
- expert_indices_out,
- profile_logs,
- )
-
-
-def expert_select_one_stage(
- scores_in: torch.Tensor,
- bias_in: torch.Tensor,
- expert_probs_out: torch.Tensor,
- expert_indices_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """Expert Select operation.
-
- Modified one-stage expert select operation used in Kimi and GLM.
- """
- torch.ops.tilert.expert_select_glm5_op(
- scores_in,
- bias_in,
- expert_probs_out,
- expert_indices_out,
- profile_logs,
- )
diff --git a/python/models/deepseek_v3_2/ops/head_proj.py b/python/models/deepseek_v3_2/ops/head_proj.py
deleted file mode 100644
index 5ab8b77..0000000
--- a/python/models/deepseek_v3_2/ops/head_proj.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""HeadProj operation module."""
-
-import torch
-
-__all__ = [
- "head_proj",
-]
-
-
-def head_proj(
- hidden_in: torch.Tensor,
- weight_in: torch.Tensor,
- logits_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """Head Projection operation."""
- torch.ops.tilert.head_proj_op(
- hidden_in,
- weight_in,
- logits_out,
- profile_logs,
- )
diff --git a/python/models/deepseek_v3_2/ops/projo_wkvb.py b/python/models/deepseek_v3_2/ops/projo_wkvb.py
deleted file mode 100644
index 618e5ee..0000000
--- a/python/models/deepseek_v3_2/ops/projo_wkvb.py
+++ /dev/null
@@ -1,283 +0,0 @@
-"""UnprojOB operation module."""
-
-from dataclasses import dataclass
-from enum import Enum
-
-import torch
-
-from tilert.models.base import TileRTModule, TilertWeightsConverter
-from tilert.models.common import init_func, weight_dequant
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
-from tilert.utils import get_profile_log_tensor
-
-__all__ = [
- "projo_wkvb",
- "ProjoWKVb",
- "ProjoWKVbAlgorithm",
- "ProjoWKVbWeightsConverter",
- "ProjoWKVbRefWeightsAlias",
- "ProjoWKVbTilertWeightsAlias",
-]
-
-
-def projo_wkvb(
- o_in: torch.Tensor,
- wkv_b_b: torch.Tensor,
- wkv_b_scales: torch.Tensor,
- output: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Define the UnprojOB operation.
-
- Args:
- o_in: Input tensor.
- wkv_b_b: Weight tensor.
- wkv_b_scales: Scale tensor.
- output: Output tensor.
- profile_logs: Profile logs tensor.
- """
- # Choose operation based on v_head_dim (128 for deepseek_v3_2, 256 for glm5)
- if output.shape[-1] == 128:
- torch.ops.tilert.projo_wkvb_op(o_in, wkv_b_b, wkv_b_scales, output, profile_logs)
- elif output.shape[-1] == 256:
- torch.ops.tilert.proj_ob_glm5_op(o_in, wkv_b_b, wkv_b_scales, output, profile_logs)
- else:
- raise ValueError(f"Unsupported v_head_dim: {output.shape[-1]}")
-
-
-class ProjoWKVbAlgorithm(Enum):
- """ProjoWKVb algorithm"""
-
- GENERAL = "general"
-
-
-class ProjoWKVbWeightsConverter(TilertWeightsConverter):
- def __init__(self, model_args: ModelArgs, num_devices: int):
- super().__init__(model_args, num_devices)
-
- def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- with torch.inference_mode():
- tilert_wkv_b_weights, tilert_wkv_b_scales = weights
-
- # Input weights are already in the correct shape from device_sharding:
- # wkv_b_weights: (n_local_heads, v_head_dim, kv_lora_rank)
- # wkv_b_scales: (n_local_heads, v_head_dim // block_size, kv_lora_rank // block_size)
- wkv_b_b = tilert_wkv_b_weights.contiguous()
- wkv_b_b_scales = tilert_wkv_b_scales.contiguous()
- if self.model_args.arch_name == "glm_5":
- if wkv_b_b_scales.dtype != torch.float32:
- print(
- "Warning: ProjoWKVbWeightsConverter: "
- + f"wkv_b_b_scales.dtype: {wkv_b_b_scales.dtype} "
- + "is not float32, convert to float32."
- )
- wkv_b_b_scales = wkv_b_b_scales.to(torch.float32)
- else: # DS v3.2, use bfloat16 for wkv_b_b_scales
- wkv_b_b_scales = wkv_b_b_scales.to(torch.bfloat16)
-
- wkv_b_b = wkv_b_b.detach()
- wkv_b_b_scales = wkv_b_b_scales.detach()
-
- return wkv_b_b, wkv_b_b_scales
-
-
-@dataclass
-class ProjoWKVbRefWeightsAlias:
- """Reference weights alias for ProjoWKVb."""
-
- wkv_b_weights = "self_attn.kv_b_proj.weight"
- wkv_b_scales = "self_attn.kv_b_proj.weight_scale_inv"
-
- @property
- def ref_tensor_alias(self) -> list[str]:
- return [self.wkv_b_weights, self.wkv_b_scales]
-
- def __call__(self) -> list[str]:
- return self.ref_tensor_alias
-
-
-@dataclass
-class ProjoWKVbTilertWeightsAlias:
- """TileRT weights alias for ProjoWKVb."""
-
- wkv_b_weights = "wkv_b2_weights"
- wkv_b_scales = "wkv_b2_scales"
-
- @property
- def tilert_tensor_alias(self) -> list[str]:
- return [self.wkv_b_weights, self.wkv_b_scales]
-
- def __call__(self) -> list[str]:
- return self.tilert_tensor_alias
-
-
-class ProjoWKVb(TileRTModule):
- """ProjoWKVb module: O projection (wkv_b) for output."""
-
- def __init__(
- self,
- model_args: ModelArgs,
- num_devices: int,
- device_id: int = 0,
- ref_weights_alias: ProjoWKVbRefWeightsAlias | None = None,
- ):
- super().__init__(
- self.__class__.__name__,
- model_args=model_args,
- num_devices=num_devices,
- device_id=device_id,
- )
-
- self.tilert_weights_alias = ProjoWKVbTilertWeightsAlias()
- self.ref_weights_alias = (
- ref_weights_alias if ref_weights_alias is not None else ProjoWKVbRefWeightsAlias()
- )
-
- self.ref_wkv_b: torch.Tensor | None = None
- self.tilert_wkv_b_b: torch.Tensor | None = None
- self.tilert_wkv_b_b_scales: torch.Tensor | None = None
- self.output: torch.Tensor | None = None
- self.profile_logs: torch.Tensor | None = None
-
- self.num_local_heads = self.model_args.n_heads // self.num_devices
-
- # lora dim and quant block size
- self.wkvb_lora_rank = self.model_args.kv_lora_rank
- self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size
-
- self.wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim
- self.wkvb_v_head_dim = self.model_args.v_head_dim
- left_head_dim = self.wkvb_head_dim % self.model_args.block_size
- if left_head_dim != 0:
- assert self.model_args.block_size % left_head_dim == 0
- self.head_dim_block_size = left_head_dim
- self.head_dim_scale_repeat = self.model_args.block_size // self.head_dim_block_size
- else:
- self.head_dim_scale_repeat = 1
- self.head_dim_block_size = self.model_args.block_size
- self.wkvb_head_qsize = self.wkvb_head_dim // self.head_dim_block_size
- self.wkvb_v_head_qsize = self.wkvb_v_head_dim // self.head_dim_block_size
-
- def get_weights_list(self) -> list[torch.Tensor]:
- return [self.tilert_wkv_b_b, self.tilert_wkv_b_b_scales]
-
- def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
- """
- Device sharding: split weights and scales per device.
-
- Args:
- weights_map: Map from ref weight alias to tensor.
-
- Returns:
- Map from tilert weight alias to (num_devices, ...) tensors.
- """
- kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights]
- kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales]
-
- dev_heads = (self.num_devices, self.num_local_heads)
- wkvb = kv_b_proj_weight.view(*dev_heads, self.wkvb_head_dim, self.wkvb_lora_rank)[
- :, :, -self.wkvb_v_head_dim :
- ]
- wkvb_scales = (
- kv_b_proj_weight_scale.view(
- self.num_devices,
- self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size,
- 1,
- self.wkvb_lora_rank_qsize,
- )
- .contiguous()
- .repeat(1, 1, self.head_dim_scale_repeat, 1)
- .view(
- self.num_devices,
- self.num_local_heads,
- self.wkvb_head_qsize,
- self.wkvb_lora_rank_qsize,
- )
- .contiguous()[:, :, -self.wkvb_v_head_qsize :]
- )
- return {
- self.tilert_weights_alias.wkv_b_weights: wkvb.contiguous(),
- self.tilert_weights_alias.wkv_b_scales: wkvb_scales.contiguous(),
- }
-
- def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- sharding_size = self.num_local_heads * self.wkvb_head_dim
- sharding_start = self.device_id * sharding_size
- sharding_end = sharding_start + sharding_size
- wkv_b = weight_dequant(
- state_dict[self.ref_weights_alias.wkv_b_weights],
- state_dict[self.ref_weights_alias.wkv_b_scales],
- )
- wkv_b = wkv_b[sharding_start:sharding_end, :]
- wkv_b = wkv_b.view(self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank)
- self.ref_wkv_b = wkv_b[:, -self.wkvb_v_head_dim :]
-
- def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- self.tilert_wkv_b_b, self.tilert_wkv_b_b_scales = ProjoWKVbWeightsConverter(
- self.model_args, self.num_devices
- ).dispatch(
- ProjoWKVbAlgorithm.GENERAL,
- [
- state_dict[self.tilert_weights_alias.wkv_b_weights],
- state_dict[self.tilert_weights_alias.wkv_b_scales],
- ],
- )
-
- def init_random_weights(self) -> None:
- wkv_b = init_func(
- torch.empty(
- self.model_args.n_heads * self.wkvb_head_dim,
- self.wkvb_lora_rank,
- dtype=torch.float8_e4m3fn,
- )
- )
- wkv_b_scales = init_func(
- torch.empty(
- # Block quant should be applied to the original weight dimension (including head
- # dimension)
- self.model_args.n_heads * self.wkvb_head_dim // self.model_args.block_size,
- self.wkvb_lora_rank_qsize,
- dtype=torch.float32,
- )
- )
- ref_state_dict = dict(
- zip(
- self.ref_weights_alias(),
- [wkv_b, wkv_b_scales],
- )
- )
- self.init_reference_weights(ref_state_dict)
- sharded = self.device_sharding(ref_state_dict)
- self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()})
-
- def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
- self.output = torch.zeros(
- (batch_size, seq_len, self.num_local_heads, self.wkvb_v_head_dim),
- dtype=torch.bfloat16,
- )
- self.profile_logs = get_profile_log_tensor()
- self.is_var_init = True
-
- def golden_forward(self, x_out: torch.Tensor) -> torch.Tensor:
- assert self.ref_wkv_b is not None
- return torch.einsum("bshc,hdc->bshd", x_out, self.ref_wkv_b)
-
- def tilert_forward(self, x_out: torch.Tensor) -> torch.Tensor:
- assert self.tilert_wkv_b_b is not None
- assert self.tilert_wkv_b_b_scales is not None
- assert self.output is not None
- assert self.profile_logs is not None
- projo_wkvb(
- x_out,
- self.tilert_wkv_b_b,
- self.tilert_wkv_b_b_scales,
- self.output,
- self.profile_logs,
- )
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
- return self.output
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_proj_top1.py b/python/models/deepseek_v3_2/ops/rmsnorm_proj_top1.py
deleted file mode 100644
index c6ec1a5..0000000
--- a/python/models/deepseek_v3_2/ops/rmsnorm_proj_top1.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""RMSNorm + head projection + top1 operation"""
-
-import torch
-
-__all__ = [
- "rmsnorm_proj_top1",
-]
-
-
-def rmsnorm_proj_top1(
- hidden_in: torch.Tensor,
- rmsnorm_gamma_in: torch.Tensor,
- head_projection_weights_in: torch.Tensor,
- token_id: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Define the RMSNormProjTop1 operation.
-
- Args:
- hidden_in: Input tensor.
- rmsnorm_gamma_in: Weight tensor.
- head_projection_weights_in: Weight tensor.
- token_id: Output tensor.
- profile_logs: Profile logs tensor.
- """
- torch.ops.tilert.rmsnorm_proj_top1_op(
- hidden_in, rmsnorm_gamma_in, head_projection_weights_in, token_id, profile_logs
- )
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_projq_wqib.py b/python/models/deepseek_v3_2/ops/rmsnorm_projq_wqib.py
deleted file mode 100644
index 7adcad6..0000000
--- a/python/models/deepseek_v3_2/ops/rmsnorm_projq_wqib.py
+++ /dev/null
@@ -1,689 +0,0 @@
-"""RmsnormProjqWqib operation module."""
-
-from dataclasses import dataclass
-from enum import Enum
-
-import torch
-from einops import rearrange
-
-from tilert.models.base import TileRTModule, TilertWeightsConverter
-from tilert.models.common import weight_dequant
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import (
- ExpertSelectUpGateSiLUWeightsConverter as WeightsConverter,
-)
-from tilert.profiler.utils import parse_profile_log_tensor
-from tilert.utils import get_profile_log_tensor
-
-__all__ = [
- "RmsnormProjqWqib",
- "RmsnormProjqWqibAlgorithm",
- "RmsnormProjqWqibWeightsConverter",
-]
-
-
-def rmsnorm_projq_wqib_op(
- q: torch.Tensor,
- wq_b_full: torch.Tensor,
- wq_b_full_scales: torch.Tensor,
- q_norm_weight: torch.Tensor,
- q_nope: torch.Tensor,
- q_pe: torch.Tensor,
- iq: torch.Tensor,
- profile_logs: torch.Tensor,
- algorithm: str,
-) -> None:
- dim = q.shape[-1]
- if dim == 1536:
- impl_func = torch.ops.tilert.rmsnorm_proj_qb_iq_op
- elif dim == 2048:
- impl_func = torch.ops.tilert.rmsnorm_proj_qb_iq_glm5_op
- else:
- raise ValueError(f"Invalid dimension: {dim}")
- impl_func(
- q,
- wq_b_full,
- wq_b_full_scales,
- q_norm_weight,
- q_nope,
- q_pe,
- iq,
- profile_logs,
- algorithm,
- )
-
-
-class RmsnormProjqWqibAlgorithm(Enum):
- """RmsnormProjqWqib algorithm."""
-
- BF16 = "bf16"
- FP8 = "fp8"
- FP16MMA = "fp16mma"
-
-
-class RmsnormProjqWqibWeightsConverter(TilertWeightsConverter):
- """Weights converter: common format to TileRT format."""
-
- def __init__(self, model_args: ModelArgs, num_devices: int):
- super().__init__(model_args=model_args, num_devices=num_devices)
-
- self.proc_groups = 8
- self.repeat = 16
-
- self.block_size = self.model_args.block_size
- self.n_local_heads = self.model_args.n_heads // self.num_devices
-
- self.q_lora_dim = self.model_args.q_lora_rank
- self.q_lora_qdim = self.q_lora_dim // self.block_size
-
- self.qk_nope_head_dim = self.model_args.qk_nope_head_dim
- self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
- self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
- self.qk_dim = self.qk_head_dim * self.n_local_heads
- self.qk_qdim = self.qk_dim // self.block_size
-
- self.index_n_heads = self.model_args.index_n_heads
- self.index_head_dim = self.index_n_heads * self.model_args.index_head_dim
- self.index_head_qdim = self.index_head_dim // self.block_size
-
- def _common_to_tilert_bf16(
- self,
- wq_b: torch.Tensor,
- wq_b_scales_raw: torch.Tensor,
- wq_b_iq: torch.Tensor,
- wq_b_iq_scales: torch.Tensor,
- rmsnorm_gamma: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert common weights to TileRT BF16 layout."""
- wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim)
- wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :]
- wq_b_nope = wq_b_nope.reshape(
- self.n_local_heads,
- self.proc_groups,
- self.qk_nope_head_dim // self.proc_groups,
- self.q_lora_dim,
- )
- wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :]
- wq_b_pe = wq_b_pe.reshape(
- self.n_local_heads,
- self.proc_groups,
- self.qk_rope_head_dim // self.proc_groups,
- self.q_lora_dim,
- )
- wq_b = torch.cat([wq_b_nope, wq_b_pe], dim=2)
- wq_b = wq_b.reshape(self.qk_dim, self.q_lora_dim)
- wq_b_full = torch.cat([wq_b, wq_b_iq], dim=0)
-
- wq_b_scales_iq_raw = wq_b_iq_scales
- wq_b_scales_t16 = (
- wq_b_scales_raw.reshape((self.qk_qdim, 1, self.q_lora_qdim))
- .repeat(1, self.repeat, 1)
- .reshape(self.qk_qdim * self.repeat, self.q_lora_qdim)
- )
- wq_b_scales_t16 = wq_b_scales_t16.reshape(
- self.n_local_heads, self.qk_head_dim // self.proc_groups, self.q_lora_qdim
- )
- wq_b_scales_t16_nope = wq_b_scales_t16[:, : self.qk_nope_head_dim // 8]
- wq_b_scales_t16_pe = wq_b_scales_t16[:, self.qk_nope_head_dim // 8 :]
- wq_b_scales_t16_nope = wq_b_scales_t16_nope.reshape(
- self.n_local_heads,
- self.proc_groups,
- self.qk_nope_head_dim // 8 // self.proc_groups,
- self.q_lora_qdim,
- )
- wq_b_scales_t16_pe = wq_b_scales_t16_pe.reshape(
- self.n_local_heads,
- self.proc_groups,
- self.qk_rope_head_dim // 8 // self.proc_groups,
- self.q_lora_qdim,
- )
- wq_b_scales_t16 = torch.cat([wq_b_scales_t16_nope, wq_b_scales_t16_pe], dim=2)
- wq_b_scales_t16 = wq_b_scales_t16.reshape(-1, self.q_lora_qdim)
- wq_b_scales_full = torch.cat([wq_b_scales_t16, wq_b_scales_iq_raw], dim=0)
-
- return (
- wq_b_full.detach().clone(),
- wq_b_scales_full.detach().clone(),
- rmsnorm_gamma.float().detach().clone(),
- )
-
- def _common_to_tilert_fp8(
- self,
- wq_b: torch.Tensor,
- wq_b_scales_raw: torch.Tensor,
- wq_b_iq: torch.Tensor,
- wq_b_iq_scales_raw: torch.Tensor,
- rmsnorm_gamma: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert common weights to TileRT FP8 MMA layout."""
- # Reshape wq_b: simple split of nope and pe, then concatenate
- wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim)
- wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_dim)
- wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_dim)
- wq_b = torch.cat([wq_b_nope, wq_b_pe], dim=0)
-
- # Process scales: expand and split nope/pe similarly to weights
- m_scale_group = self.block_size // self.repeat
- wq_b_scales_t16 = (
- wq_b_scales_raw.reshape((self.qk_qdim, 1, self.q_lora_qdim))
- .repeat(1, self.repeat, 1)
- .reshape(-1, self.qk_head_dim // m_scale_group, self.q_lora_qdim)
- )
-
- # Split nope and pe parts
- wq_b_scales_nope = wq_b_scales_t16[:, : self.qk_nope_head_dim // m_scale_group, :].reshape(
- [-1, self.q_lora_qdim]
- )
- wq_b_scales_pe = wq_b_scales_t16[:, self.qk_nope_head_dim // m_scale_group :, :].reshape(
- [-1, self.q_lora_qdim]
- )
- wq_b_scales_t16 = torch.cat([wq_b_scales_nope, wq_b_scales_pe], dim=0)
-
- # Process wq_b_iq scales
- wq_b_iq_scales_t16 = (
- wq_b_iq_scales_raw.reshape([self.index_head_qdim, 1, self.q_lora_qdim])
- .repeat([1, self.repeat, 1])
- .reshape((-1, self.q_lora_qdim))
- )
-
- # Concatenate weights and scales
- wq_b_raw = torch.cat([wq_b, wq_b_iq], dim=0)
- page_k = self.q_lora_qdim
- total_out_dim = self.qk_dim + self.index_head_dim
- total_out_qdim = total_out_dim // self.block_size
- wq_b_scales_full = (
- torch.cat(
- [wq_b_scales_t16.to(torch.float32), wq_b_iq_scales_t16.to(torch.float32)], dim=0
- )
- .reshape((total_out_qdim, self.repeat, page_k, self.q_lora_qdim // page_k))
- .permute([0, 2, 1, 3])
- .contiguous()
- .view(torch.float8_e4m3fn)
- )
-
- wq_b_raw = wq_b_raw.reshape(
- [total_out_qdim, 128 // 16, 16, page_k, self.q_lora_dim // 32 // page_k, 32]
- ).permute([0, 3, 1, 4, 2, 5])
- wq_b_raw = WeightsConverter._swizzle_mma_16x32(wq_b_raw)
-
- tilert_wq_b_full = torch.cat(
- [
- wq_b_raw.reshape((total_out_qdim, page_k, -1)),
- wq_b_scales_full.reshape([total_out_qdim, page_k, -1]),
- ],
- -1,
- ).contiguous()
- # TODO: use fp32 scale for glm_5
- tilert_wq_b_full_scales = torch.zeros(1, dtype=torch.bfloat16)
- tilert_q_norm_weight = rmsnorm_gamma.float().detach().clone()
- return tilert_wq_b_full, tilert_wq_b_full_scales, tilert_q_norm_weight
-
- @staticmethod
- def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
- assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
- # PTX isa fig.88
- pre_shape = mat_in.shape[:-2]
- mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
- return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
-
- @staticmethod
- def _swizzle_mma_16x16_for_16x2048_4pages(mat_in: torch.Tensor) -> torch.Tensor:
- assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 2048
- pre_shape = mat_in.shape[:-2]
- mat_in = mat_in.reshape(*pre_shape, 16, 4, 512).transpose(-3, -2)
- mat_in = mat_in.reshape(*pre_shape, 4, 16, 32, 16).transpose(-3, -2)
- mat_in = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16(mat_in)
- return mat_in.contiguous()
-
- def _common_to_tilert_fp16mma(
- self,
- wq_b: torch.Tensor,
- wq_b_scale: torch.Tensor,
- wq_b_iq: torch.Tensor,
- wq_b_iq_scale: torch.Tensor,
- q_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert common weights to TileRT FP16 MMA layout."""
- assert self.model_args.arch_name == "glm_5", "Only GLM-5 supports FP16 MMA"
-
- if wq_b_scale.dtype != torch.float32:
- print(
- "Warning: RmsnormProjqWqibWeightsConverter: "
- + f"wq_b_scale.dtype: {wq_b_scale.dtype} "
- + "is not float32, convert to float32."
- )
- wq_b_scale = wq_b_scale.to(torch.float32)
- if wq_b_iq_scale.dtype != torch.float32:
- print(
- "Warning: RmsnormProjqWqibWeightsConverter: "
- + f"wq_b_iq_scale.dtype: {wq_b_iq_scale.dtype} "
- + "is not float32, convert to float32."
- )
- wq_b_iq_scale = wq_b_iq_scale.to(torch.float32)
-
- sms = 128 # use 128 sms for glm_5
- pages = 4
- qk_dim = self.qk_head_dim * self.n_local_heads
- qk_dim_per_sm = qk_dim // sms # 16 per sm
- qk_nope_dim = self.n_local_heads * self.qk_nope_head_dim
- qk_pe_dim = self.n_local_heads * self.qk_rope_head_dim
- iq_dim_per_sm = self.index_head_dim // sms # 32 per sm
-
- wq_b_scale = wq_b_scale.reshape(
- self.n_local_heads, self.qk_head_dim // self.block_size, 1, self.q_lora_qdim
- ).repeat(
- 1, 1, self.block_size, 1
- ) # 2048, 2048//128
-
- wq_b_scale = wq_b_scale.reshape(self.n_local_heads, self.qk_head_dim, -1)
- wq_b_nope_scale = (
- wq_b_scale[:, : self.qk_nope_head_dim, :]
- .reshape(qk_nope_dim // qk_dim_per_sm, qk_dim_per_sm, pages, self.q_lora_qdim // pages)
- .transpose(1, 2) # (96, 4, 16, 4) for glm_5
- )
-
- wq_b_pe_scale = (
- wq_b_scale[:, self.qk_nope_head_dim :, :]
- .reshape(qk_pe_dim // qk_dim_per_sm, qk_dim_per_sm, pages, self.q_lora_qdim // pages)
- .transpose(1, 2) # (32, 4, 16, 4) for glm_5
- )
- wq_b_scale = torch.cat([wq_b_nope_scale, wq_b_pe_scale], dim=0)
- wq_b_scale = wq_b_scale[:, :, 0, :] # (128, 4, 4) for glm_5
-
- wq_b_iq_scale = wq_b_iq_scale.reshape(self.index_head_qdim, 1, self.q_lora_qdim).repeat(
- 1, self.block_size, 1
- ) # (4096, 16) for glm_5
- wq_b_iq_scale = wq_b_iq_scale.reshape(
- sms, iq_dim_per_sm, pages, self.q_lora_qdim // pages
- ).transpose(1, 2)
- wq_b_iq_scale = wq_b_iq_scale[:, :, 0, :] # (128, 4, 4) for glm_5
-
- wq_b_full_scales = (
- torch.cat([wq_b_scale, wq_b_iq_scale], dim=-1).contiguous().view(torch.float8_e4m3fn)
- ) # (128, 4, 8x4) for glm_5
-
- wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim)
- wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_dim) # 8x192, 2048
- wq_b_nope = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16_for_16x2048_4pages(
- wq_b_nope.reshape(qk_nope_dim // qk_dim_per_sm, qk_dim_per_sm, self.q_lora_dim)
- )
- wq_b_nope = wq_b_nope.reshape(qk_nope_dim // qk_dim_per_sm, pages, qk_dim_per_sm, -1)
- # (96, 4, 16, 512) for glm_5
-
- wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_dim) # 8x64, 2048
- wq_b_pe = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16_for_16x2048_4pages(
- wq_b_pe.reshape(qk_pe_dim // qk_dim_per_sm, qk_dim_per_sm, self.q_lora_dim)
- )
- wq_b_pe = wq_b_pe.reshape(qk_pe_dim // qk_dim_per_sm, pages, qk_dim_per_sm, -1)
- # (32, 4, 16, 512) for glm_5
- wq_b = torch.cat([wq_b_nope, wq_b_pe], dim=0)
- # (128, 4, 16, 512) for glm_5
-
- wq_b_iq = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16_for_16x2048_4pages(
- wq_b_iq.reshape(sms, 2, iq_dim_per_sm // 2, self.q_lora_dim)
- )
- wq_b_iq = (
- wq_b_iq.reshape(sms, 2, pages, iq_dim_per_sm // 2, -1)
- .transpose(1, 2)
- .reshape(sms, pages, iq_dim_per_sm, -1)
- )
- # (128, 4, 32, 512) for glm_5
- wq_b = torch.cat([wq_b, wq_b_iq], dim=2)
- wq_b = wq_b.reshape(sms, pages, -1)
- # (128, 4, 48*512) for glm_5
- wq_b_scales_padding = torch.zeros(
- sms,
- pages,
- 128 - wq_b_full_scales.shape[-1],
- dtype=torch.float8_e4m3fn,
- device=wq_b.device,
- ) # append 128-byte aligned scale: (128, 4, 24704) for glm_5
- tilert_wq_b_full = torch.cat(
- [wq_b, wq_b_full_scales, wq_b_scales_padding], dim=-1
- ).contiguous()
- tilert_wq_b_dummy_scales = torch.zeros(1, dtype=torch.bfloat16)
- tilert_q_norm_weight = q_norm_weight.float().detach().clone()
- return tilert_wq_b_full, tilert_wq_b_dummy_scales, tilert_q_norm_weight
-
- def convert_to_bf16(
- self, weights: list[torch.Tensor]
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert common-format weights to TileRT BF16 layout.
-
- Args:
- weights: [q_norm_weight, wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale].
- """
- with torch.inference_mode():
- wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale, q_norm_weight = weights
- if self.model_args.arch_name == "glm_5":
- if wq_b_scale.dtype != torch.float32:
- print(
- "Warning: RmsnormProjqWqibWeightsConverter: "
- + f"wq_b_scale.dtype: {wq_b_scale.dtype} "
- + "is not float32, convert to float32."
- )
- wq_b_scales = wq_b_scale.to(torch.float32)
- wq_b_iq_scales = wq_b_iq_scale.to(torch.float32)
- return self._common_to_tilert_bf16(
- wq_b,
- wq_b_scales,
- wq_b_iq,
- wq_b_iq_scales,
- q_norm_weight,
- )
-
- # DS v3.2, use bfloat16 for wq_b_scale and wq_b_iq_scale
- wq_b_scales_bf16 = wq_b_scale.to(torch.bfloat16)
- wq_b_iq_scales_bf16 = wq_b_iq_scale.to(torch.bfloat16)
- return self._common_to_tilert_bf16(
- wq_b,
- wq_b_scales_bf16,
- wq_b_iq,
- wq_b_iq_scales_bf16,
- q_norm_weight,
- )
-
- def convert_to_fp8(
- self, weights: list[torch.Tensor]
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert common-format weights to TileRT FP8 MMA layout.
-
- Args:
- weights: [q_norm_weight, wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale].
- """
- with torch.inference_mode():
- wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale, q_norm_weight = weights
- return self._common_to_tilert_fp8(
- wq_b,
- wq_b_scale,
- wq_b_iq,
- wq_b_iq_scale,
- q_norm_weight,
- )
-
- def convert_to_fp16mma(
- self, weights: list[torch.Tensor]
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert common-format weights to TileRT FP16 MMA layout.
-
- Args:
- weights: [q_norm_weight, wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale].
- """
- with torch.inference_mode():
- wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale, q_norm_weight = weights
- return self._common_to_tilert_fp16mma(
- wq_b,
- wq_b_scale,
- wq_b_iq,
- wq_b_iq_scale,
- q_norm_weight,
- )
-
-
-@dataclass
-class RmsnormProjqWqibRefWeightsAlias:
- """Reference weights alias for RmsnormProjqWqib."""
-
- rmsnorm_gamma = "self_attn.q_a_layernorm.weight"
- wqb_weights = "self_attn.q_b_proj.weight"
- wqb_scales = "self_attn.q_b_proj.weight_scale_inv"
- wi_weights = "self_attn.indexer.wq_b.weight"
- wi_scales = "self_attn.indexer.wq_b.weight_scale_inv"
-
- @property
- def ref_tensor_alias(self) -> list[str]:
- return [
- self.rmsnorm_gamma,
- self.wqb_weights,
- self.wqb_scales,
- self.wi_weights,
- self.wi_scales,
- ]
-
- def __call__(self) -> list[str]:
- return self.ref_tensor_alias
-
-
-@dataclass
-class RmsnormProjqWqibTilertWeightsAlias:
- """TileRT weights alias for RmsnormProjqWqib."""
-
- rmsnorm_gamma = "q_rmsnorm_gamma"
- wqb_weights = "wqb_weights"
- wqb_scales = "wqb_scales"
- wi_weights = "wi_weights"
- wi_scales = "wi_scales"
-
- @property
- def tilert_tensor_alias(self) -> list[str]:
- return [
- self.rmsnorm_gamma,
- self.wqb_weights,
- self.wqb_scales,
- self.wi_weights,
- self.wi_scales,
- ]
-
- def __call__(self) -> list[str]:
- return self.tilert_tensor_alias
-
-
-class RmsnormProjqWqib(TileRTModule):
- """RmsnormProjqWqib module: RMSNorm + Q projection (wq_b + wq_b_iq)."""
-
- def __init__(
- self,
- model_args: ModelArgs,
- device_id: int,
- num_devices: int,
- ref_weights_alias: RmsnormProjqWqibRefWeightsAlias | None = None,
- ):
- super().__init__(
- self.__class__.__name__,
- model_args=model_args,
- device_id=device_id,
- num_devices=num_devices,
- )
-
- self.tilert_weights_alias = RmsnormProjqWqibTilertWeightsAlias()
- self.ref_weights_alias = (
- ref_weights_alias
- if ref_weights_alias is not None
- else RmsnormProjqWqibRefWeightsAlias()
- )
-
- self.n_local_heads = model_args.n_heads // num_devices
- self.q_lora_rank = model_args.q_lora_rank
- self.index_n_heads = model_args.index_n_heads
- self.head_dim = model_args.index_head_dim
- self.index_head_dim = model_args.index_n_heads * model_args.index_head_dim
- self.n_heads = model_args.n_heads
- self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim
- self.qk_local_dim = self.qk_head_dim * self.n_local_heads
- self.qk_nope_head_dim = model_args.qk_nope_head_dim
- self.qk_rope_head_dim = model_args.qk_rope_head_dim
-
- # quantize block size
- self.block_size = model_args.block_size
- self.q_lora_qdim = self.q_lora_rank // self.block_size
- self.qk_local_qdim = self.qk_local_dim // self.block_size
- self.index_head_qdim = self.index_head_dim // self.block_size
- self.eps = model_args.eps
-
- self.ref_q_norm: torch.Tensor | None = None
- self.ref_wq_b: torch.Tensor | None = None
- self.ref_wq_b_iq: torch.Tensor | None = None
-
- self.tilert_wq_b_full: torch.Tensor | None = None
- self.tilert_wq_b_full_scales: torch.Tensor | None = None
- self.tilert_q_norm_weight: torch.Tensor | None = None
-
- self.q_nope: torch.Tensor | None = None
- self.q_pe: torch.Tensor | None = None
- self.iq: torch.Tensor | None = None
-
- self.profile_logs: torch.Tensor | None = None
-
- def get_weights_list(self) -> list[torch.Tensor]:
- return [self.tilert_q_norm_weight, self.tilert_wq_b_full, self.tilert_wq_b_full_scales]
-
- def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
- """Device sharding."""
- gamma = weights_map[self.ref_weights_alias.rmsnorm_gamma][None, ...].repeat(
- self.num_devices, 1
- )
-
- sharded_wqb_weights = weights_map[self.ref_weights_alias.wqb_weights].reshape(
- self.num_devices, self.qk_local_dim, self.q_lora_rank
- )
- sharded_wi_weights = weights_map[self.ref_weights_alias.wi_weights][None, ...].repeat(
- self.num_devices, 1, 1
- )
-
- sharded_wqb_scales = weights_map[self.ref_weights_alias.wqb_scales].reshape(
- self.num_devices, self.qk_local_qdim, self.q_lora_qdim
- )
- sharded_wi_scales = weights_map[self.ref_weights_alias.wi_scales][None, ...].repeat(
- self.num_devices, 1, 1
- )
-
- return {
- self.tilert_weights_alias.rmsnorm_gamma: gamma,
- self.tilert_weights_alias.wqb_weights: sharded_wqb_weights,
- self.tilert_weights_alias.wqb_scales: sharded_wqb_scales,
- self.tilert_weights_alias.wi_weights: sharded_wi_weights,
- self.tilert_weights_alias.wi_scales: sharded_wi_scales,
- }
-
- def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- """Initialize reference weights from common-format state dict."""
- self.ref_q_norm = state_dict[self.ref_weights_alias.rmsnorm_gamma]
- qk_local_dim_start = self.qk_local_dim * self.device_id
- qk_local_qdim_start = qk_local_dim_start // self.block_size
- qk_local_dim_end = qk_local_dim_start + self.qk_local_dim
- qk_local_qdim_end = qk_local_dim_end // self.block_size
- wq_b = weight_dequant(
- state_dict[self.ref_weights_alias.wqb_weights][qk_local_dim_start:qk_local_dim_end],
- state_dict[self.ref_weights_alias.wqb_scales][qk_local_qdim_start:qk_local_qdim_end],
- )
- wq_b_iq = weight_dequant(
- state_dict[self.ref_weights_alias.wi_weights],
- state_dict[self.ref_weights_alias.wi_scales],
- )
- self.ref_wq_b = wq_b.contiguous()
- self.ref_wq_b_iq = wq_b_iq.contiguous()
-
- def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- """Initialize TileRT weights from common-format state dict."""
- weights = [
- state_dict[_k]
- for _k in [
- self.tilert_weights_alias.wqb_weights,
- self.tilert_weights_alias.wqb_scales,
- self.tilert_weights_alias.wi_weights,
- self.tilert_weights_alias.wi_scales,
- self.tilert_weights_alias.rmsnorm_gamma,
- ]
- ]
- assert self.algorithm is not None, "Algorithm is not set"
- self.tilert_wq_b_full, self.tilert_wq_b_full_scales, self.tilert_q_norm_weight = (
- RmsnormProjqWqibWeightsConverter(self.model_args, self.num_devices).dispatch(
- self.algorithm, weights
- )
- )
-
- def init_random_weights(self) -> None:
- """Initialize random reference and TileRT weights for testing."""
- q_norm = torch.randn(self.q_lora_rank, dtype=torch.float32)
- wq_b = torch.randn(
- self.num_devices * self.qk_local_dim, self.q_lora_rank, dtype=torch.bfloat16
- ).to(torch.float8_e4m3fn)
- scale_dtype = torch.float32 if self.model_args.arch_name == "glm_5" else torch.bfloat16
- wq_b_scale = torch.randn(
- self.num_devices * self.qk_local_qdim, self.q_lora_qdim, dtype=scale_dtype
- )
- wq_b_iq = torch.randn(self.index_head_dim, self.q_lora_rank, dtype=torch.bfloat16).to(
- torch.float8_e4m3fn
- )
- wq_b_iq_scale = torch.randn(self.index_head_qdim, self.q_lora_qdim, dtype=scale_dtype)
- ref_state = {
- self.ref_weights_alias.rmsnorm_gamma: q_norm,
- self.ref_weights_alias.wqb_weights: wq_b,
- self.ref_weights_alias.wqb_scales: wq_b_scale,
- self.ref_weights_alias.wi_weights: wq_b_iq,
- self.ref_weights_alias.wi_scales: wq_b_iq_scale,
- }
-
- self.init_reference_weights(ref_state)
- self.init_tilert_weights(
- {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state).items()}
- )
-
- def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
- """Allocate TileRT output buffers."""
- self.q_nope = torch.zeros(
- batch_size, seq_len, self.n_local_heads, self.qk_nope_head_dim, dtype=torch.bfloat16
- )
- self.q_pe = torch.zeros(
- batch_size, seq_len, self.n_local_heads, self.qk_rope_head_dim, dtype=torch.bfloat16
- )
- self.iq = torch.zeros(
- batch_size, seq_len, self.index_n_heads, self.head_dim, dtype=torch.bfloat16
- )
- self.profile_logs = get_profile_log_tensor()
- self.is_var_init = True
-
- def golden_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Reference forward: RMSNorm + linear projections."""
- assert self.ref_q_norm is not None
- assert self.ref_wq_b is not None
- assert self.ref_wq_b_iq is not None
-
- bsz, seqlen, _ = q.shape
- if bsz != 1 or seqlen not in [1, 2, 4]:
- raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
-
- qr = torch.nn.functional.rms_norm(q.float(), [q.size(-1)], self.ref_q_norm, self.eps).to(
- q.dtype
- )
-
- q = torch.matmul(qr, self.ref_wq_b.T)
- q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
- q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
- q_idx = torch.matmul(qr, self.ref_wq_b_iq.T)
- q_idx = rearrange(q_idx, "b s (h d) -> b s h d", d=self.head_dim)
- return q_nope, q_pe, q_idx
-
- def tilert_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- assert self.tilert_wq_b_full is not None
- assert self.tilert_wq_b_full_scales is not None
- assert self.tilert_q_norm_weight is not None
- assert self.q_nope is not None
- assert self.q_pe is not None
- assert self.iq is not None
- assert self.profile_logs is not None
-
- bsz, seqlen, _ = q.shape
- if bsz != 1 or seqlen not in [1, 2, 4]:
- raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
-
- assert self.algorithm is not None, "Algorithm is not set"
-
- rmsnorm_projq_wqib_op(
- q,
- self.tilert_wq_b_full,
- self.tilert_wq_b_full_scales,
- self.tilert_q_norm_weight,
- self.q_nope,
- self.q_pe,
- self.iq,
- self.profile_logs,
- self.algorithm.value,
- )
-
- if self.flag_enable_profiling_log:
- torch.cuda.synchronize()
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
- return self.q_nope, self.q_pe, self.iq
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_projx_wqkvia.py b/python/models/deepseek_v3_2/ops/rmsnorm_projx_wqkvia.py
deleted file mode 100644
index d6538ed..0000000
--- a/python/models/deepseek_v3_2/ops/rmsnorm_projx_wqkvia.py
+++ /dev/null
@@ -1,1095 +0,0 @@
-"""RMSNormProjxWqkvia operation module."""
-
-from collections.abc import Callable
-from dataclasses import dataclass
-from enum import Enum
-
-# from typing import Any
-import torch
-
-from tilert.models.base import TileRTModule, TilertWeightsConverter
-from tilert.models.common import weight_dequant
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant
-from tilert.profiler.utils import parse_profile_log_tensor
-from tilert.utils import get_profile_log_tensor
-
-__all__ = [
- "RMSNormProjQAKVAKIWeightsConverter",
- "RMSNormProjxWqkviaAlgorithm",
- "RMSNormProjxWqkvia",
- "RMSNormProjxWqkviaRefWeightsAlias",
- "RMSNormProjxWqkviaTilertWeightsAlias",
- "rmsnorm_projx_wqkvia",
- "projx_wqkvia",
-]
-
-
-def rmsnorm_projx_wqkvia(
- x_in: torch.Tensor,
- wqkv_a: torch.Tensor,
- wqkv_a_scales: torch.Tensor,
- rmsnorm_gamma: torch.Tensor,
- cur_pos: torch.Tensor,
- q_out: torch.Tensor,
- kv_out: torch.Tensor,
- pe_cache: torch.Tensor,
- ki_out: torch.Tensor,
- x_rmsnorm_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- rmsnorm_projx_wqkvia operation.
-
- Args:
- x_in: Input tensor.
- wqkv_a: QKV weights.
- wqkv_a_scales: QKV scales.
- rmsnorm_gamma: RMSNorm gamma.
- cur_pos: Current position.
- q_out: Q output tensor.
- kv_out: KV output tensor.
- pe_cache: PE cache tensor.
- ki_out: Ki output tensor.
- x_rmsnorm_out: RMSNorm output tensor.
- profile_logs: Profile logs tensor.
- """
- torch.ops.tilert.rmsnorm_proj_qa_kva_ki_op(
- x_in,
- wqkv_a,
- wqkv_a_scales,
- rmsnorm_gamma,
- cur_pos,
- q_out,
- kv_out,
- pe_cache,
- ki_out,
- x_rmsnorm_out,
- profile_logs,
- )
-
-
-def projx_wqkvia(
- x_quant: torch.Tensor,
- x_scale: torch.Tensor,
- wqkvia: torch.Tensor,
- cur_pos: torch.Tensor,
- out_q: torch.Tensor,
- out_kv: torch.Tensor,
- pe_cache: torch.Tensor,
- out_ki: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Define the ProjXWQKVIa operation.
-
- Args:
- x_quant: Input tensor.
- x_scale: Weight tensor.
- wqkvia: Weight tensor.
- cur_pos: Current position tensor.
- out_q: Output tensor.
- out_kv: Output tensor.
- pe_cache: Output tensor.
- out_ki: Output tensor.
- profile_logs: Profile logs tensor.
- """
- dim = x_quant.shape[-1]
- if dim == 6144:
- func_call = torch.ops.tilert.projx_wqkvia_glm5
- elif dim == 7168:
- func_call = torch.ops.tilert.projx_wqkvia_op
- else:
- raise ValueError(f"Unsupported dimension: {dim}")
- func_call(x_quant, x_scale, wqkvia, cur_pos, out_q, out_kv, pe_cache, out_ki, profile_logs)
-
-
-class RMSNormProjxWqkviaAlgorithm(Enum):
- """RMSNormProjxWqkvia algorithm"""
-
- GENERAL = "general" # fused
- DECOUPLED = "decoupled" # rmsnorm_quant + projx_wqkvia
-
-
-class RMSNormProjQAKVAKIWeightsConverter:
- """Weights converter class."""
-
- @staticmethod
- def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
- assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
- # PTX isa fig.88
- pre_shape = mat_in.shape[:-2]
- mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
- return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
-
- @staticmethod
- def tilert_to_common(
- tilert_wqkv_a: torch.Tensor,
- tilert_wqkv_a_scales: torch.Tensor,
- tilert_attn_norm_weight: torch.Tensor,
- ) -> tuple[
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- ]:
- """
- Convert tilert weights to common weights.
-
- Args:
- tilert_wqkv_a: Tilert weight tensor.
- tilert_wqkv_a_scales: Tilert weight scale tensor.
- tilert_attn_norm_weight: Tilert attention norm weight tensor.
- Returns:
- tuple: Common weights.
- """
- wq_a = tilert_wqkv_a[:1536] # 1536, 7168
- wkv_a = tilert_wqkv_a[1536 : 1536 + 576] # 576, 7168
- wk = tilert_wqkv_a[1536 + 576 :] # 128, 7168
-
- wqkv_a_scales_0 = tilert_wqkv_a_scales[:128, :].reshape(16, 8, 64)
- wqkv_a_scales_0 = wqkv_a_scales_0[:, 0, :].reshape(16, 64)
- wqkv_a_scales_1 = tilert_wqkv_a_scales[128:129, :] # 1, 64
- wqkv_a_scales_2 = tilert_wqkv_a_scales[129:, :] # 1, 64
- wqkv_a_scales_swizzled = torch.cat(
- [wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0
- )
- wqkv_scales = torch.zeros(
- (18, 56), dtype=torch.bfloat16, device=tilert_wqkv_a_scales.device
- )
-
- for i in range(64):
- if ((i % 8) * 8 + i // 8) < 56:
- wqkv_scales[:, ((i % 8) * 8 + i // 8)] = wqkv_a_scales_swizzled[:, i]
- wq_a_scale = wqkv_scales[:12, :] # 12, 56
- wkv_a_scale = wqkv_scales[12:17, :] # 5, 56
- wk_scale = wqkv_scales[17:, :] # 1, 56
-
- attn_norm_weight = tilert_attn_norm_weight
- return wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale, attn_norm_weight
-
- @staticmethod
- def common_to_tilert(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Convert common weights to tilert weights.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights.
- """
- wqkv_a = torch.cat([wq_a, wkv_a, wk], dim=0)
- wqkv_a_scales_raw = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0)
-
- wqkv_a_scales = torch.zeros((18, 64), dtype=torch.bfloat16, device=wq_a_scale.device)
- for i in range(64):
- wqkv_a_scales[:, i] = wqkv_a_scales_raw[:, ((i % 8) * 8 + i // 8) % 56]
- if ((i % 8) * 8 + i // 8) >= 56:
- wqkv_a_scales[:, i] = 0.0
- wqkv_a_scales_0 = wqkv_a_scales[:16, :]
- wqkv_a_scales_1 = wqkv_a_scales[16:17, :]
- wqkv_a_scales_2 = wqkv_a_scales[17:, :]
-
- wqkv_a_scales_0 = wqkv_a_scales_0.reshape((16, 1, 64)).repeat(1, 8, 1).reshape(-1, 64)
- wqkv_a_scales = torch.cat([wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0)
- assert wqkv_a_scales.shape == (130, 64)
- return wqkv_a.contiguous(), wqkv_a_scales.contiguous(), attn_norm_weight.clone()
-
- @staticmethod
- def common_to_tilert_fp8(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to tilert weights.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert fp8 weights.
- """
- wq_a_raw: torch.Tensor = wq_a.detach().clone()
- wkv_a_raw: torch.Tensor = wkv_a.detach().clone()
- wq_a_raw = torch.cat([wq_a_raw, wkv_a_raw[:512], wk, wkv_a_raw[512:]], dim=0)
-
- wq_a_raw = wq_a_raw.reshape(35, 64, 14, 512)
- wq_a_raw = wq_a_raw.permute(0, 2, 1, 3)
-
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 128)
- wq_a_copy = wq_a_raw.contiguous().clone()
- wq_a_raw[:, :, 1::2, :, :, :64] = wq_a_copy[:, :, 1::2, :, :, 64:]
- wq_a_raw[:, :, 1::2, :, :, 64:] = wq_a_copy[:, :, 1::2, :, :, :64]
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 2, 64)
- wq_a_copy = wq_a_raw.contiguous().clone()
- wq_a_raw[:, :, :, 2:, :, :, :32] = wq_a_copy[:, :, :, 2:, :, :, 32:]
- wq_a_raw[:, :, :, 2:, :, :, 32:] = wq_a_copy[:, :, :, 2:, :, :, :32]
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 2, 2, 32)
- wq_a_copy = wq_a_raw.contiguous().clone()
- wq_a_raw[:, :, :, 1::2, :, :, :, :16] = wq_a_copy[:, :, :, 1::2, :, :, :, 16:]
- wq_a_raw[:, :, :, 1::2, :, :, :, 16:] = wq_a_copy[:, :, :, 1::2, :, :, :, :16]
-
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 128)
- wq_a_raw = wq_a_raw.permute(0, 1, 4, 2, 3, 5).reshape(35, 14, -1).contiguous()
- wq_a_raw = wq_a_raw.reshape(35, 14, -1).contiguous()
-
- wq_s_raw: torch.Tensor = wq_a_scale.detach().clone()
- wkv_s_raw: torch.Tensor = wkv_a_scale.detach().clone()
- wq_s_raw = torch.cat([wq_s_raw, wkv_s_raw[:4], wk_scale, wkv_s_raw[4:]], dim=0)
- wq_s_raw = wq_s_raw.reshape(18, 1, 14, 4).repeat(1, 2, 1, 1).reshape(36, 1, 14, 4)
- wq_s_raw = wq_s_raw[:35].reshape(35, 14, -1).contiguous()
- wq_s_raw = wq_s_raw.view(torch.float8_e4m3fn)
- wq_as_raw = torch.cat([wq_a_raw, wq_s_raw], dim=-1)
-
- return wq_as_raw.contiguous(), attn_norm_weight.clone()
-
- @staticmethod
- def common_to_tilert_native_bf16(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to weights for tilert native bf16 op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights for native bf16 op.
- """
- wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168))
- wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168))
- wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168))
- wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168))
- wkv_a_scale = wkv_a_scale[:576]
- wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168))
- wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168))
- wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float()
- wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float()
- wk = wk.reshape((128, 7168)).float() * wk_scale.float()
- weights = torch.cat([wq_a, wkv_a, wk], dim=0)
- assert weights.shape == (1536 + 576 + 128, 7168)
- return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone()
-
- @staticmethod
- def common_to_tilert_native_bf16_warp_gemv(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to weights for tilert native bf16 warp gemv op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights for native bf16 warp gemv op.
- """
- wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168))
- wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168))
- wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168))
- wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168))
- wkv_a_scale = wkv_a_scale[:576]
- wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168))
- wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168))
- wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float()
- wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float()
- wk = wk.reshape((128, 7168)).float() * wk_scale.float()
- # concatenate the weights
- weights = torch.cat([wq_a, wkv_a, wk], dim=0)
- assert weights.shape == (1536 + 576 + 128, 7168)
-
- weights = weights.reshape(140, 16, 7, 1024)
- weights = weights.transpose(1, 2) # 140, 7, 16, 1024
- return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone()
-
- @staticmethod
- def common_to_tilert_dequant_bf16(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to weights for tilert dequant bf16 op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights for dequant bf16 op.
- """
- wq_a = wq_a.reshape((384, 4, 7168))
- wkv_a = wkv_a.reshape((144, 4, 7168))
- wk = wk.reshape((32, 4, 7168))
- wqkv = torch.cat([wq_a, wkv_a, wk], dim=0).reshape(140, 4, 4 * 7168)
-
- wq_a_scale = wq_a_scale.reshape((12, 1, 56)).repeat(1, 32, 1).reshape((384, 1, 56))
- wkv_a_scale = wkv_a_scale.reshape((5, 1, 56)).repeat(1, 32, 1).reshape((160, 1, 56))[:144]
- wk_scale = wk_scale.reshape((1, 1, 56)).repeat(1, 32, 1).reshape((32, 1, 56))
- wqkv_scales = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0).reshape(140, 4, 56)
- wqkv_scales_swizzled = torch.zeros(140, 4, 64, dtype=torch.bfloat16, device=wq_a.device)
- # swizzle
- for i in range(64):
- wqkv_scales_swizzled[..., i] = wqkv_scales[..., ((i % 8) * 8 + i // 8) % 56]
- weights = torch.zeros(
- 140, 4, 4 * 7168 + 64 * 2, dtype=torch.float8_e4m3fn, device=wq_a.device
- )
- weights_part = weights[:, :, : 4 * 7168]
- scales_part = weights[:, :, 4 * 7168 :]
- weights_part.copy_(wqkv)
- scales_part.copy_(wqkv_scales_swizzled.view(dtype=torch.float8_e4m3fn))
- return weights.contiguous(), attn_norm_weight.clone()
-
- @staticmethod
- def common_to_tilert_fp8_mma(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- rmsnorm_gamma: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to weights for tilert fp8 mma op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- rmsnorm_gamma: Common rmsnorm gamma tensor.
- Returns:
- tuple: Tilert weights for fp8 mma op.
- """
- assert wq_a.shape == (1536, 7168)
- assert wq_a_scale.shape == (12, 56)
- assert wkv_a.shape == (576, 7168)
- assert wkv_a_scale.shape == (5, 56)
- assert wk.shape == (128, 7168)
- assert wk_scale.shape == (1, 56)
- wq_a = wq_a.reshape(96, 16, 7168)
- wq_a_scale = wq_a_scale.reshape(12, 1, 56).repeat(1, 8, 1).reshape(96, 56)
- wkv_a = wkv_a.reshape(36, 16, 7168)
- wkv_a_scale = wkv_a_scale.reshape(5, 1, 56).repeat(1, 8, 1).reshape(40, 56)
- wkv_a_scale = wkv_a_scale[:36]
-
- wk = wk.reshape(8, 16, 7168)
- wk_scale = wk_scale.reshape(1, 1, 56).repeat(1, 8, 1).reshape(8, 56)
- wqkvia = torch.cat([wq_a, wkv_a, wk], dim=0) # 140, 7168
- wqkvia_scale = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) # 140, 56
-
- wqkvia_0 = wqkvia[..., :2048]
- wqkvia_0_scale = wqkvia_scale[..., :16].contiguous().view(torch.float8_e4m3fn)
- wqkvia_1 = wqkvia[..., 2048:4096]
- wqkvia_1_scale = wqkvia_scale[..., 16:32].contiguous().view(torch.float8_e4m3fn)
- wqkvia_2 = wqkvia[..., 4096:6144]
- wqkvia_2_scale = wqkvia_scale[..., 32:48].contiguous().view(torch.float8_e4m3fn)
- wqkvia_3 = wqkvia[..., 6144:7168]
- wqkvia_3_scale = wqkvia_scale[..., 48:56].contiguous().view(torch.float8_e4m3fn)
-
- wqkvia_0 = wqkvia_0.reshape(140, 16, 64, 32).transpose(1, 2)
- wqkvia_0 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_0)
- wqkvia_0 = wqkvia_0.reshape(140, 16 * 2048)
-
- wqkvia_1 = wqkvia_1.reshape(140, 16, 64, 32).transpose(1, 2)
- wqkvia_1 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_1)
- wqkvia_1 = wqkvia_1.reshape(140, 16 * 2048)
-
- wqkvia_2 = wqkvia_2.reshape(140, 16, 64, 32).transpose(1, 2)
- wqkvia_2 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_2)
- wqkvia_2 = wqkvia_2.reshape(140, 16 * 2048)
-
- wqkvia_3 = wqkvia_3.reshape(140, 16, 32, 32).transpose(1, 2)
- wqkvia_3 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_3)
- wqkvia_3 = wqkvia_3.reshape(140, 16 * 1024)
- padding_scale0 = torch.zeros((140, 48), dtype=torch.bfloat16, device=wq_a.device).view(
- torch.float8_e4m3fn
- )
- padding_scale1 = torch.zeros((140, 48), dtype=torch.bfloat16, device=wq_a.device).view(
- torch.float8_e4m3fn
- )
- padding_scale2 = torch.zeros((140, 48), dtype=torch.bfloat16, device=wq_a.device).view(
- torch.float8_e4m3fn
- )
- padding_scale3 = torch.zeros((140, 56), dtype=torch.bfloat16, device=wq_a.device).view(
- torch.float8_e4m3fn
- )
- wqkvia = torch.cat(
- [
- wqkvia_0,
- wqkvia_0_scale,
- padding_scale0,
- wqkvia_1,
- wqkvia_1_scale,
- padding_scale1,
- wqkvia_2,
- wqkvia_2_scale,
- padding_scale2,
- wqkvia_3,
- wqkvia_3_scale,
- padding_scale3,
- ],
- dim=1,
- )
-
- return wqkvia.contiguous(), rmsnorm_gamma.contiguous()
-
-
-class RMSNormProjxWqkviaWeightsConverter(TilertWeightsConverter):
- """RMSNormProjxWqkvia weights converter"""
-
- @staticmethod
- def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
- assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
- assert mat_in.dtype == torch.float8_e4m3fn
- # PTX isa fig.88
- pre_shape = mat_in.shape[:-2]
- mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
- return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
-
- def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert the weights to general format.
-
- Args:
- weights: List of weights.
-
- Returns:
- Tuple of weights.
- """
- # Specialized for DS v3.2 model
- args = self.model_args
- assert (
- args.arch_name == "deepseek_v3_2"
- ), f"arch_name must be deepseek_v3_2, but got {args.arch_name}"
- with torch.inference_mode():
- x_rmsnorm_gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale = weights
- q_lora_rank_scale_dim = args.q_lora_rank // args.block_size
- kv_lora_rank_scale_dim = args.kv_lora_rank // args.block_size + 1
- x_scale_dim = args.dim // args.block_size
-
- wq_a_scale = (
- wq_a_scale.reshape((q_lora_rank_scale_dim, x_scale_dim, 1))
- .repeat(1, 1, args.block_size)
- .reshape((q_lora_rank_scale_dim, 1, args.dim))
- )
- wq_a_scale = wq_a_scale.repeat(1, args.block_size, 1).reshape(
- (args.q_lora_rank, args.dim)
- )
- wkv_a_scale = (
- wkv_a_scale.reshape((kv_lora_rank_scale_dim, x_scale_dim, 1))
- .repeat(1, 1, args.block_size)
- .reshape((kv_lora_rank_scale_dim, 1, args.dim))
- )
- wkv_a_scale = wkv_a_scale.repeat(1, args.block_size, 1).reshape((-1, args.dim))
- wkv_a_scale = wkv_a_scale[: args.kv_lora_rank + args.qk_rope_head_dim]
- wk_scale = (
- wk_scale.reshape((1, x_scale_dim, 1))
- .repeat(1, 1, args.block_size)
- .reshape((1, 1, args.dim))
- )
- wk_scale = wk_scale.repeat(1, args.block_size, 1).reshape(
- (args.index_head_dim, args.dim)
- )
- wq_a = wq_a.reshape((args.q_lora_rank, args.dim)).float() * wq_a_scale.float()
- wkv_a = (
- wkv_a.reshape((args.kv_lora_rank + args.qk_rope_head_dim, args.dim)).float()
- * wkv_a_scale.float()
- )
- wk = wk.reshape((args.index_head_dim, args.dim)).float() * wk_scale.float()
- # concatenate the weights
- weights_tensor: torch.Tensor = torch.cat([wq_a, wkv_a, wk], dim=0)
- assert weights_tensor.shape == (
- args.q_lora_rank + args.kv_lora_rank + args.qk_rope_head_dim + args.index_head_dim,
- args.dim,
- )
- # hard-coded scheduling: reshape to 140, 16, 7, 1024
- weights_tensor = weights_tensor.reshape(140, 16, 7, 1024)
- weights_tensor = weights_tensor.transpose(1, 2) # 140, 7, 16, 1024
- return x_rmsnorm_gamma, weights_tensor.to(torch.bfloat16).contiguous()
-
- def convert_to_decoupled(
- self, weights: list[torch.Tensor]
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert the weights to decoupled format.
-
- Args:
- weights: List of weights.
-
- Returns:
- Tuple of weights.
- """
- arch_name = self.model_args.arch_name
- wqkvia_and_scales = None
- with torch.inference_mode():
- x_rmsnorm_gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale = weights
- # Ensure the scales are in bfloat16
- if arch_name == "deepseek_v3_2": # DS v3.2
- # Ensure the scales are in bfloat16 for DS v3.2
- wq_a_scale = wq_a_scale.to(torch.bfloat16)
- wkv_a_scale = wkv_a_scale.to(torch.bfloat16)
- wk_scale = wk_scale.to(torch.bfloat16)
- assert wq_a.shape == (1536, 7168)
- assert wq_a_scale.shape == (12, 56)
- assert wkv_a.shape == (576, 7168)
- assert wkv_a_scale.shape == (5, 56)
- assert wk.shape == (128, 7168)
- assert wk_scale.shape == (1, 56)
- wq_a = wq_a.reshape(96, 16, 7168)
- wq_a_scale = wq_a_scale.reshape(12, 1, 56).repeat(1, 8, 1).reshape(96, 56)
- wkv_a = wkv_a.reshape(36, 16, 7168)
- wkv_a_scale = wkv_a_scale.reshape(5, 1, 56).repeat(1, 8, 1).reshape(40, 56)
- wkv_a_scale = wkv_a_scale[:36]
-
- wk = wk.reshape(8, 16, 7168)
- wk_scale = wk_scale.reshape(1, 1, 56).repeat(1, 8, 1).reshape(8, 56)
- wqkvia = torch.cat([wq_a, wkv_a, wk], dim=0) # 140, 7168
- wqkvia_scale = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) # 140, 56
-
- wqkvia_0 = wqkvia[..., :2048]
- wqkvia_0_scale = wqkvia_scale[..., :16].contiguous().view(torch.float8_e4m3fn)
- wqkvia_1 = wqkvia[..., 2048:4096]
- wqkvia_1_scale = wqkvia_scale[..., 16:32].contiguous().view(torch.float8_e4m3fn)
- wqkvia_2 = wqkvia[..., 4096:6144]
- wqkvia_2_scale = wqkvia_scale[..., 32:48].contiguous().view(torch.float8_e4m3fn)
- wqkvia_3 = wqkvia[..., 6144:7168]
- wqkvia_3_scale = wqkvia_scale[..., 48:56].contiguous().view(torch.float8_e4m3fn)
-
- wqkvia_0 = wqkvia_0.reshape(140, 16, 64, 32).transpose(1, 2)
- wqkvia_0 = self._swizzle_qmma_16x32(wqkvia_0)
- wqkvia_0 = wqkvia_0.reshape(140, 16 * 2048)
-
- wqkvia_1 = wqkvia_1.reshape(140, 16, 64, 32).transpose(1, 2)
- wqkvia_1 = self._swizzle_qmma_16x32(wqkvia_1)
- wqkvia_1 = wqkvia_1.reshape(140, 16 * 2048)
-
- wqkvia_2 = wqkvia_2.reshape(140, 16, 64, 32).transpose(1, 2)
- wqkvia_2 = self._swizzle_qmma_16x32(wqkvia_2)
- wqkvia_2 = wqkvia_2.reshape(140, 16 * 2048)
-
- wqkvia_3 = wqkvia_3.reshape(140, 16, 32, 32).transpose(1, 2)
- wqkvia_3 = self._swizzle_qmma_16x32(wqkvia_3)
- wqkvia_3 = wqkvia_3.reshape(140, 16 * 1024)
- padding_scale0 = torch.zeros(
- (140, 48), dtype=torch.bfloat16, device=wq_a.device
- ).view(torch.float8_e4m3fn)
- padding_scale1 = torch.zeros(
- (140, 48), dtype=torch.bfloat16, device=wq_a.device
- ).view(torch.float8_e4m3fn)
- padding_scale2 = torch.zeros(
- (140, 48), dtype=torch.bfloat16, device=wq_a.device
- ).view(torch.float8_e4m3fn)
- padding_scale3 = torch.zeros(
- (140, 56), dtype=torch.bfloat16, device=wq_a.device
- ).view(torch.float8_e4m3fn)
- wqkvia_and_scales = torch.cat(
- [
- wqkvia_0,
- wqkvia_0_scale,
- padding_scale0,
- wqkvia_1,
- wqkvia_1_scale,
- padding_scale1,
- wqkvia_2,
- wqkvia_2_scale,
- padding_scale2,
- wqkvia_3,
- wqkvia_3_scale,
- padding_scale3,
- ],
- dim=1,
- )
- elif arch_name == "glm_5": # GLM5
- # Ensure the scales are in float32 for DS v3.2
- if wq_a_scale.dtype != torch.float32:
- # TODO: remove this after the source weights are converted to float32
- print(
- "Warning: RMSNormProjxWqkviaWeightsConverter: "
- + "wq_a_scale is not in float32, converting to float32."
- )
- wq_a_scale = wq_a_scale.to(torch.float32)
- wkv_a_scale = wkv_a_scale.to(torch.float32)
- wk_scale = wk_scale.to(torch.float32)
- # (2048 + 576 + 128, 6144)
- wqkvia = torch.cat([wq_a, wkv_a, wk], dim=0).reshape(86, 32, 6144)
- # (16+5+1, 48)
- wq_a_scale = wq_a_scale.reshape((16, 1, 48)).repeat(1, 4, 1).reshape(64, 48)
- wkv_a_scale = wkv_a_scale.reshape((5, 1, 48)).repeat(1, 4, 1).reshape(20, 48)[:18]
- wk_scale = wk_scale.reshape((1, 1, 48)).repeat(1, 4, 1).reshape(4, 48)
- wqkvia_scales = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) # (86, 48)
- wqkvia = wqkvia.reshape(86, 32, 6, 1024).transpose(1, 2).reshape(86, 6, 2, 16, 1024)
- wqkvia = wqkvia.reshape(86, 6, 2, 16, 32, 32).transpose(3, 4)
- wqkvia = self._swizzle_qmma_16x32(wqkvia).reshape(86, 6, 32 * 1024)
- wqkvia_scales = wqkvia_scales.reshape(86, 6, 8).view(torch.float8_e4m3fn)
- wqkvia_padding = torch.zeros(
- (86, 6, 128 - wqkvia_scales.shape[-1]),
- dtype=torch.float8_e4m3fn,
- device=wq_a.device,
- )
- wqkvia_and_scales = torch.cat([wqkvia, wqkvia_scales, wqkvia_padding], dim=-1)
- else:
- raise ValueError(f"Unsupported architecture: {arch_name}")
- assert wqkvia_and_scales is not None
- return x_rmsnorm_gamma.float(), wqkvia_and_scales.contiguous()
-
-
-@dataclass
-class RMSNormProjxWqkviaRefWeightsAlias:
- """Reference weights alias for RMSNormProjxWqkvia."""
-
- x_rmsnorm_gamma = "input_layernorm.weight"
- q_a_weights = "self_attn.q_a_proj.weight"
- q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
- kv_a_with_mqa_weights = "self_attn.kv_a_proj_with_mqa.weight"
- kv_a_with_mqa_scales = "self_attn.kv_a_proj_with_mqa.weight_scale_inv"
- wk_weights = "self_attn.indexer.wk.weight"
- wk_scales = "self_attn.indexer.wk.weight_scale_inv"
-
- @property
- def ref_tensor_alias(self) -> list[str]:
- return [
- self.x_rmsnorm_gamma,
- self.q_a_weights,
- self.q_a_scales,
- self.kv_a_with_mqa_weights,
- self.kv_a_with_mqa_scales,
- self.wk_weights,
- self.wk_scales,
- ]
-
- def __call__(self) -> list[str]:
- return self.ref_tensor_alias
-
-
-@dataclass
-class RMSNormProjxWqkviaTilertWeightsAlias:
- """TileRT weights alias for RMSNormProjxWqkvia."""
-
- x_rmsnorm_gamma = "x_rmsnorm_gamma"
- q_a_weights = "q_a_weights"
- q_a_scales = "q_a_scales"
- kv_a_with_mqa_weights = "kv_a_with_mqa_weights"
- kv_a_with_mqa_scales = "kv_a_with_mqa_scales"
- wk_weights = "wk_weights"
- wk_scales = "wk_scales"
-
- @property
- def tilert_tensor_alias(self) -> list[str]:
- return [
- self.x_rmsnorm_gamma,
- self.q_a_weights,
- self.q_a_scales,
- self.kv_a_with_mqa_weights,
- self.kv_a_with_mqa_scales,
- self.wk_weights,
- self.wk_scales,
- ]
-
- def __call__(self) -> list[str]:
- return self.tilert_tensor_alias
-
-
-class RMSNormProjxWqkvia(TileRTModule):
- """RMSNormProjxWqkvia module"""
-
- def __init__(
- self,
- model_args: ModelArgs,
- num_devices: int,
- device_id: int,
- ref_weights_alias: RMSNormProjxWqkviaRefWeightsAlias | None = None,
- algorithm: RMSNormProjxWqkviaAlgorithm = RMSNormProjxWqkviaAlgorithm.GENERAL,
- ):
- super().__init__(
- self.__class__.__name__,
- model_args=model_args,
- num_devices=num_devices,
- device_id=device_id,
- )
-
- self.tilert_weights_alias = RMSNormProjxWqkviaTilertWeightsAlias()
- self.ref_weights_alias = (
- ref_weights_alias
- if ref_weights_alias is not None
- else RMSNormProjxWqkviaRefWeightsAlias()
- )
-
- self.arch_name = self.model_args.arch_name
- self.dim = self.model_args.dim
- self.q_lora_rank = self.model_args.q_lora_rank
- self.kv_lora_rank = self.model_args.kv_lora_rank
- self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
- self.idx_head_dim = self.model_args.index_head_dim
- self.block_size = self.model_args.block_size
- self.eps = self.model_args.eps
- self.algorithm: RMSNormProjxWqkviaAlgorithm = algorithm
-
- # reference weights
- self.ref_norm_gamma: torch.Tensor | None = None
- self.ref_wq_a: torch.Tensor | None = None
- self.ref_wkv_a: torch.Tensor | None = None
- self.ref_wk: torch.Tensor | None = None
-
- # tilert weights
- self.tilert_norm_gamma: torch.Tensor | None = None
- self.tilert_wqkv_a: torch.Tensor | None = None
- # Legacy scale tensor for compatibility, to be removed in the future
- self.tilert_wqkv_a_scales = torch.zeros((130, 64), dtype=torch.bfloat16)
-
- # tilert vars
- self.x_rmsnorm_out: torch.Tensor | None = None
- self.q_out: torch.Tensor | None = None
- self.kv_out: torch.Tensor | None = None
- self.ki_out: torch.Tensor | None = None
- self.x_rmsnorm_quant_out: torch.Tensor | None = None
- self.x_rmsnorm_quant_scale_out: torch.Tensor | None = None
-
- self.profile_logs: torch.Tensor | None = None
- self.is_init = False
-
- # tilert_funcs
- self.rmsnorm_proj_func: Callable | None = None
- self.rmsnorm_func: Callable | None = None
- self.proj_func: Callable | None = None
-
- if self.arch_name == "deepseek_v3_2":
- self.rmsnorm_proj_func = rmsnorm_projx_wqkvia
- self.rmsnorm_func = rmsnorm_quant
- self.proj_func = projx_wqkvia
- elif self.arch_name == "glm_5":
- # Lazy import to avoid circular import
- self.rmsnorm_proj_func = None
- self.rmsnorm_func = rmsnorm_quant
- self.proj_func = projx_wqkvia
- else:
- raise ValueError(f"Unsupported architecture: {self.arch_name}")
-
- # tilert tensor aliases (3 output weight names for get_weights_list)
- self.tilert_tensor_alias: list[str] = [
- "x_rmsnorm_gamma",
- "qkv_wa_weights",
- "qkv_wa_scales",
- ]
-
- def get_weights_list(self) -> list[torch.Tensor]:
- """
- Get the weights list.
-
- Returns:
- List of weights.
- """
- assert self.algorithm is not None, "Algorithm is not set"
- if self.algorithm == RMSNormProjxWqkviaAlgorithm.GENERAL:
- return [self.tilert_norm_gamma, self.tilert_wqkv_a, self.tilert_wqkv_a_scales]
- return [self.tilert_norm_gamma, self.tilert_wqkv_a]
-
- def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
- """
- Device sharding.
-
- Args:
- input_layernorm_weight: Input layernorm weight.
- q_a_proj_weight: Q A proj weight.
- q_a_proj_weight_scale: Q A proj weight scale.
- kv_a_proj_weight: KV A proj weight.
- kv_a_proj_weight_scale: KV A proj weight scale.
- indexer_wk_weight: Indexer WK weight.
- indexer_wk_weight_scale: Indexer WK weight scale.
-
- Returns:
- Tuple of weights.
- """
- # repeat n times for device sharding
- # Using float to support both bfloat16 and float
- input_layernorm_weight = (
- weights_map[self.ref_weights_alias.x_rmsnorm_gamma][None, ...]
- .float()
- .repeat(self.num_devices, 1)
- )
- q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
- self.num_devices, 1, 1
- )
- q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
- self.num_devices, 1, 1
- )
- kv_a_proj_weight = weights_map[self.ref_weights_alias.kv_a_with_mqa_weights][
- None, ...
- ].repeat(self.num_devices, 1, 1)
- kv_a_proj_weight_scale = weights_map[self.ref_weights_alias.kv_a_with_mqa_scales][
- None, ...
- ].repeat(self.num_devices, 1, 1)
- indexer_wk_weight = weights_map[self.ref_weights_alias.wk_weights][None, ...].repeat(
- self.num_devices, 1, 1
- )
- indexer_wk_weight_scale = weights_map[self.ref_weights_alias.wk_scales][None, ...].repeat(
- self.num_devices, 1, 1
- )
- return {
- self.tilert_weights_alias.x_rmsnorm_gamma: input_layernorm_weight,
- self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
- self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
- self.tilert_weights_alias.kv_a_with_mqa_weights: kv_a_proj_weight,
- self.tilert_weights_alias.kv_a_with_mqa_scales: kv_a_proj_weight_scale,
- self.tilert_weights_alias.wk_weights: indexer_wk_weight,
- self.tilert_weights_alias.wk_scales: indexer_wk_weight_scale,
- }
-
- def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- """
- Initialize the reference weights.
-
- Args:
- state_dict: State dictionary.
- """
- self.ref_norm_gamma = state_dict[self.ref_weights_alias()[0]]
- self.ref_wq_a = weight_dequant(
- state_dict[self.ref_weights_alias()[1]], state_dict[self.ref_weights_alias()[2]]
- )
- self.ref_wkv_a = weight_dequant(
- state_dict[self.ref_weights_alias()[3]], state_dict[self.ref_weights_alias()[4]]
- )
- self.ref_wk = weight_dequant(
- state_dict[self.ref_weights_alias()[5]], state_dict[self.ref_weights_alias()[6]]
- )
-
- assert self.ref_norm_gamma is not None
- assert self.ref_wq_a is not None
- assert self.ref_wkv_a is not None
- assert self.ref_wk is not None
-
- assert (
- self.ref_norm_gamma.shape[-1] == self.dim
- ), f"norm_gamma shape must be {self.dim}, but got {self.ref_norm_gamma.shape[-1]}"
- assert self.ref_wq_a.shape[-2] == self.q_lora_rank, (
- f"wq_a shape must be {self.q_lora_rank}, " + f"but got {self.ref_wq_a.shape[-2]}"
- )
- assert (
- self.ref_wq_a.shape[-1] == self.dim
- ), f"wq_a shape must be {self.dim}, but got {self.ref_wq_a.shape[-1]}"
- assert self.ref_wkv_a.shape[-2] == self.kv_lora_rank + self.qk_rope_head_dim, (
- f"wkv_a shape must be {self.kv_lora_rank + self.qk_rope_head_dim}, "
- + f"but got {self.ref_wkv_a.shape[-2]}"
- )
- assert (
- self.ref_wkv_a.shape[-1] == self.dim
- ), f"wkv_a shape must be {self.dim}, but got {self.ref_wkv_a.shape[-1]}"
- assert (
- self.ref_wk.shape[-2] == self.idx_head_dim
- ), f"wk shape must be {self.idx_head_dim}, but got {self.ref_wk.shape[-2]}"
- assert (
- self.ref_wk.shape[-1] == self.dim
- ), f"wk shape must be {self.dim}, but got {self.ref_wk.shape[-1]}"
-
- def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- """
- Initialize the tilert weights.
-
- Args:
- state_dict: State dictionary.
- """
- assert self.algorithm is not None, "Algorithm is not set"
- self.tilert_norm_gamma, self.tilert_wqkv_a = RMSNormProjxWqkviaWeightsConverter(
- self.model_args, self.num_devices
- ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tilert_weights_alias()])
-
- def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
- """
- Initialize the tilert variables.
-
- Args:
- batch_size: Batch size.
- seq_len: Sequence length.
- """
- self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
- self.kv_out = torch.zeros((batch_size, seq_len, self.kv_lora_rank), dtype=torch.bfloat16)
- self.ki_out = torch.zeros((batch_size, seq_len, self.idx_head_dim), dtype=torch.bfloat16)
- self.x_rmsnorm_out = torch.zeros((batch_size, seq_len, self.dim), dtype=torch.bfloat16)
- if self.algorithm == RMSNormProjxWqkviaAlgorithm.DECOUPLED:
- self.x_rmsnorm_quant_out = torch.zeros(
- (batch_size, seq_len, self.dim), dtype=torch.float8_e4m3fn
- )
- self.x_rmsnorm_quant_scale_out = torch.zeros(
- (batch_size, seq_len, self.dim // self.block_size), dtype=torch.float32
- )
- self.profile_logs = get_profile_log_tensor()
- self.is_init = True
-
- def init_random_weights(self) -> None:
- """
- Initialize the random weights.
-
- Returns:
- None
- """
- q_scale_dim = self.q_lora_rank // self.block_size
- kv_scale_dim = (self.kv_lora_rank + self.qk_rope_head_dim) // self.block_size + 1
- wk_scale_dim = self.idx_head_dim // self.block_size
- dim_scale_dim = self.dim // self.block_size
- scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
-
- tensor_list = [
- torch.randn(self.dim, dtype=torch.float32),
- torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
- torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
- torch.randn(
- self.kv_lora_rank + self.qk_rope_head_dim, self.dim, dtype=torch.bfloat16
- ).to(torch.float8_e4m3fn),
- torch.randn(kv_scale_dim, dim_scale_dim, dtype=scale_dtype),
- torch.randn(self.idx_head_dim, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
- torch.randn(wk_scale_dim, dim_scale_dim, dtype=scale_dtype),
- ]
- ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
- self.init_reference_weights(ref_state_dict)
- self.init_tilert_weights(
- {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
- )
-
- def golden_forward(
- self,
- x: torch.Tensor,
- pe_cache: torch.Tensor,
- start_pos: int,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-
- assert self.ref_norm_gamma is not None
- assert self.ref_wq_a is not None
- assert self.ref_wkv_a is not None
- assert self.ref_wk is not None
-
- x_rmsnorm_out = torch.nn.functional.rms_norm(
- x.float(), [x.size(-1)], self.ref_norm_gamma, self.eps
- )
-
- q_out = torch.matmul(x_rmsnorm_out.float(), self.ref_wq_a.transpose(0, 1).float())
- kv_out = torch.matmul(x_rmsnorm_out.float(), self.ref_wkv_a.transpose(0, 1).float())
- kv_out, k_pe = torch.split(kv_out, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- bsz = k_pe.shape[0]
- seq_len = k_pe.shape[1]
- pe_cache[:bsz, start_pos : start_pos + seq_len].copy_(k_pe.to(torch.bfloat16))
- ki_out = torch.matmul(x_rmsnorm_out.float(), self.ref_wk.transpose(0, 1).float())
- return (
- x_rmsnorm_out.to(torch.bfloat16),
- q_out.to(torch.bfloat16),
- kv_out.to(torch.bfloat16),
- ki_out.to(torch.bfloat16),
- )
-
- def tilert_forward(
- self,
- x: torch.Tensor,
- pe_cache: torch.Tensor,
- start_pos: int,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- if self.algorithm == RMSNormProjxWqkviaAlgorithm.GENERAL:
- assert self.rmsnorm_proj_func is not None
- self.rmsnorm_proj_func(
- x.to(torch.bfloat16),
- self.tilert_wqkv_a,
- self.tilert_wqkv_a_scales,
- self.tilert_norm_gamma,
- torch.tensor([start_pos], dtype=torch.int32, device=x.device),
- self.q_out,
- self.kv_out,
- pe_cache,
- self.ki_out,
- self.x_rmsnorm_out,
- self.profile_logs,
- )
- elif self.algorithm == RMSNormProjxWqkviaAlgorithm.DECOUPLED:
- assert self.rmsnorm_func is not None
- assert self.proj_func is not None
- self.rmsnorm_func(
- x.to(torch.bfloat16),
- self.tilert_norm_gamma,
- self.x_rmsnorm_out,
- self.x_rmsnorm_quant_out,
- self.x_rmsnorm_quant_scale_out,
- self.profile_logs,
- )
- self.proj_func(
- self.x_rmsnorm_quant_out,
- self.x_rmsnorm_quant_scale_out,
- self.tilert_wqkv_a,
- torch.tensor([start_pos], dtype=torch.int32, device=x.device),
- self.q_out,
- self.kv_out,
- pe_cache,
- self.ki_out,
- self.profile_logs,
- )
- else:
- raise ValueError(f"Unsupported algorithm: {self.algorithm}")
-
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
- return self.x_rmsnorm_out, self.q_out, self.kv_out, self.ki_out
-
- def __call__(
- self,
- x: torch.Tensor,
- pe_cache: torch.Tensor,
- start_pos: int,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- return self.golden_forward(x, pe_cache, start_pos)
diff --git a/python/models/deepseek_v3_2/ops/top1_allreduce.py b/python/models/deepseek_v3_2/ops/top1_allreduce.py
deleted file mode 100644
index 1d500e3..0000000
--- a/python/models/deepseek_v3_2/ops/top1_allreduce.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""Top1 Allreduce operation"""
-
-import torch
-
-__all__ = [
- "top1_allreduce",
-]
-
-
-def top1_allreduce(
- logits: torch.Tensor,
- flag: int,
- index_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Define the Top1 Allreduce operation.
-
- Args:
- logits: Input tensor.
- flag: Flag.
- index_out: Output tensor.
- profile_logs: Profile logs tensor.
- """
- torch.ops.tilert.top1_allreduce_op(logits, flag, index_out, profile_logs)
diff --git a/python/models/deepseek_v3_2/ops/top_p.py b/python/models/deepseek_v3_2/ops/top_p.py
deleted file mode 100644
index 4394c2a..0000000
--- a/python/models/deepseek_v3_2/ops/top_p.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""TopP operation module."""
-
-import torch
-
-__all__ = [
- "top_p",
-]
-
-
-def top_p(
- logits: torch.Tensor,
- in_indices: torch.Tensor,
- sampling_seed: torch.Tensor,
- positions: torch.Tensor,
- is_verify_mode: bool,
- temperature: float,
- top_p: float,
- top_k: int,
- flag: int,
- indices: torch.Tensor,
- scores: torch.Tensor,
- debug_tensor: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """top_p operation.
-
- Args:
- logits (Tensor): The logits tensor.
- in_indices (Tensor): The tensor containing input indices.
- sampling_seed (Tensor): Random seeds for each sequence position.
- positions (Tensor): Token positions for each sequence element.
- is_verify_mode (bool): A flag indicating if verify mode is enabled in MTP. When set to
- `True`, the `in_indices` will be checked to check if it is in
- the top-k values.
- temperature (float): The temperature parameter, used for scaling logits in softmax
- calculations.
- top_p (float): The top-p value, used for nucleus sampling to restrict the selection to the
- smallest set of tokens whose cumulative probability is greater than or equal
- to `top_p`.
- top_k (int): The number of top-k values that occupy the top-p probability mass
- during sampling.
- flag (int): Used in all reduction.
- indices (Tensor): The tensor containing output indices.
- scores (Tensor): The tensor containing corresponding scores for the indices.
- profile_logs (Tensor): A tensor for storing profiling log data during execution in MTP.
- """
- dim = logits.shape[-1]
- if dim == 19360:
- call_func = torch.ops.tilert.top_p_glm5_op
- elif dim == 16160:
- call_func = torch.ops.tilert.top_p_op
- else:
- raise ValueError(f"Unsupported dimension: {dim}")
- call_func(
- logits,
- in_indices,
- sampling_seed,
- positions,
- is_verify_mode,
- temperature,
- top_p,
- top_k,
- flag,
- indices,
- scores,
- debug_tensor,
- profile_logs,
- )
diff --git a/python/models/deepseek_v3_2/ops/up_gate_silu.py b/python/models/deepseek_v3_2/ops/up_gate_silu.py
deleted file mode 100644
index 2f214c0..0000000
--- a/python/models/deepseek_v3_2/ops/up_gate_silu.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""UpGateSiLU operation module."""
-
-import torch
-
-__all__ = [
- "up_gate_silu",
-]
-
-
-def up_gate_silu(
- hidden_in: torch.Tensor,
- expert_indices_in: torch.Tensor,
- experts_weights_in: torch.Tensor,
- hidden_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """Up Gate SiLU operation."""
- torch.ops.tilert.up_gate_silu_op(
- hidden_in,
- expert_indices_in,
- experts_weights_in,
- hidden_out,
- profile_logs,
- )
diff --git a/python/models/deepseek_v3_2/refs/kernel.py b/python/models/deepseek_v3_2/refs/kernel.py
deleted file mode 100644
index eb5e274..0000000
--- a/python/models/deepseek_v3_2/refs/kernel.py
+++ /dev/null
@@ -1,354 +0,0 @@
-try:
- import tilelang
- import tilelang.language as T
-except ImportError:
- raise ImportError("Cannot import tilelang, please install tilelang.") from None
-
-
-import torch
-import triton
-import triton.language as tl
-
-__all__ = [
- "weight_dequant",
- "act_quant",
- "fp8_gemm",
- "fp8_index",
-]
-
-tilelang.set_log_level("WARNING")
-
-pass_configs = {
- tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
- tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
- # tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
-}
-
-FP8 = "float8_e4m3"
-BF16 = "bfloat16"
-FP32 = "float32"
-
-
-def fast_log2_ceil(x): # type: ignore
- bits_x = T.reinterpret("uint32", x)
- exp_x = (bits_x >> 23) & 0xFF
- man_bits = bits_x & ((1 << 23) - 1)
- return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
-
-
-def fast_pow2(x): # type: ignore
- bits_x = (x + 127) << 23
- return T.reinterpret("float32", bits_x)
-
-
-def fast_round_scale(amax, fp8_max_inv): # type: ignore
- return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
-
-
-@triton.jit
-def weight_dequant_kernel( # type: ignore
- x_ptr,
- s_ptr,
- y_ptr,
- M_Size: tl.constexpr,
- N_Size: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
-) -> None:
- """
- Weight dequantization kernel.
-
- Dequantizes weights using the provided scaling factors and stores the
- result.
-
- Args:
- x_ptr (tl.pointer): Pointer to the quantized weights.
- s_ptr (tl.pointer): Pointer to the scaling factors.
- y_ptr (tl.pointer): Pointer to the output buffer for dequantized
- weights.
- M (int): Number of rows in the weight matrix.
- N (int): Number of columns in the weight matrix.
- BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
-
- Returns:
- None
- """
- pid_m = tl.program_id(axis=0)
- pid_n = tl.program_id(axis=1)
- n_size = tl.cdiv(N_Size, BLOCK_SIZE)
- offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- offs = offs_m[:, None] * N_Size + offs_n[None, :]
- mask = (offs_m[:, None] < M_Size) & (offs_n[None, :] < N_Size)
- x_in = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
- s_in = tl.load(s_ptr + pid_m * n_size + pid_n)
- y_out = x_in * s_in
- tl.store(y_ptr + offs, y_out, mask=mask)
-
-
-def weight_dequant(x_in: torch.Tensor, s_in: torch.Tensor, block_size: int = 128) -> torch.Tensor:
- """
- Dequantizes the given weight tensor using the provided scale tensor.
-
- Args:
- x_in (torch.Tensor): The quantized weight tensor of shape (M, N).
- s_in (torch.Tensor): The scale tensor of shape (M//block_size,
- N//block_size).
- block_size (int, optional): The block size to use for dequantization.
- Defaults to 128.
-
- Returns:
- torch.Tensor: The dequantized weight tensor of the same shape as `x`.
-
- Raises:
- AssertionError: If `x` or `s` are not contiguous or if their dimensions
- are not 2.
- """
- assert x_in.is_contiguous() and s_in.is_contiguous(), "Input tensors must be contiguous"
- assert x_in.dim() == 2 and s_in.dim() == 2, "Input tensors must have 2 dimensions"
- M_Size, N_Size = x_in.size()
- y_out = torch.empty_like(x_in, dtype=torch.get_default_dtype())
- grid = lambda meta: ( # noqa: E731
- triton.cdiv(M_Size, meta["BLOCK_SIZE"]),
- triton.cdiv(N_Size, meta["BLOCK_SIZE"]),
- )
- weight_dequant_kernel[grid](x_in, s_in, y_out, M_Size, N_Size, BLOCK_SIZE=block_size)
- return y_out
-
-
-@tilelang.jit(pass_configs=pass_configs)
-def act_quant_kernel( # type: ignore
- N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False # type: ignore
-): # type: ignore
- M = T.symbolic("M")
- fp8_min = -448.0
- fp8_max = 448.0
- fp8_max_inv = 1 / fp8_max
- num_stages = 0 if round_scale else 2
- blk_m = 32
- group_size = 128
-
- @T.prim_func
- def act_quant_kernel_( # type: ignore
- X: T.Tensor[(M, N), in_dtype],
- Y: T.Tensor[(M, N), out_dtype],
- S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
- ): # type: ignore
- with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
- pid_m,
- pid_n,
- ):
- x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
- x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
- amax_local = T.alloc_fragment((blk_m,), scale_dtype)
- s_local = T.alloc_fragment((blk_m,), scale_dtype)
- y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
- y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
-
- for _ in T.Pipelined(1, num_stages=num_stages):
- T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
- T.copy(x_shared, x_local)
- T.reduce_absmax(x_local, amax_local, dim=1)
- for i in T.Parallel(blk_m):
- amax_local[i] = T.max(amax_local[i], 1e-4)
- if round_scale:
- s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
- else:
- s_local[i] = amax_local[i] * fp8_max_inv
- for i, j in T.Parallel(blk_m, group_size):
- y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
- for i in T.Parallel(blk_m):
- S[pid_m * blk_m + i, pid_n] = s_local[i]
- T.copy(y_local, y_shared)
- T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
-
- return act_quant_kernel_
-
-
-def act_quant(
- x: torch.Tensor, block_size: int = 128, scale_fmt: str | None = None
-) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Quantizes the input tensor `x` using block-wise quantization.
-
- Args:
- x (torch.Tensor): The input tensor to be quantized.
- Must be contiguous and its last dimension size must be divisible by `block_size`.
- block_size (int, optional): The size of the blocks to be used for quantization.
- Default is 128.
- scale_fmt (Optional[str], optional): The format of the scale. Default is None.
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - The quantized tensor with dtype `torch.float8_e4m3fn`.
- - A tensor of scaling factors with dtype `torch.float32`.
- """
- assert x.is_contiguous(), "Input tensor must be contiguous"
- assert (
- x.size(-1) % block_size == 0
- ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
- N = x.size(-1)
- y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
- s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
- kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
- kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
- return y, s
-
-
-@tilelang.jit(pass_configs=pass_configs)
-def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): # type: ignore
- assert out_dtype in [BF16, "float32"]
-
- M = T.symbolic("M")
- group_size = 128
- block_M = 32
- block_N = 128
- block_K = 128
-
- @T.prim_func
- def fp8_gemm_kernel_( # type: ignore
- A: T.Tensor[(M, K), FP8],
- B: T.Tensor[(N, K), FP8],
- C: T.Tensor[(M, N), out_dtype],
- scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
- scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
- ): # type: ignore
- with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
- bx,
- by,
- ):
- A_shared = T.alloc_shared((block_M, block_K), FP8)
- B_shared = T.alloc_shared((block_N, block_K), FP8)
- C_shared = T.alloc_shared((block_M, block_N), out_dtype)
- Scale_C_shared = T.alloc_shared((block_M), FP32)
- C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
- C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
-
- # Improve L2 Cache
- T.use_swizzle(panel_size=10)
-
- T.clear(C_local)
- T.clear(C_local_accum)
- K_iters = T.ceildiv(K, block_K)
- for k in T.Pipelined(K_iters, num_stages=4):
- # Load A into shared memory
- T.copy(A[by * block_M, k * block_K], A_shared)
- # Load B into shared memory
- T.copy(B[bx * block_N, k * block_K], B_shared)
- # Load scale into shared memory
- Scale_B = scales_b[bx * block_N // group_size, k]
- for i in T.Parallel(block_M):
- Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
-
- T.gemm(A_shared, B_shared, C_local, transpose_B=True)
- # Promote to enable 2xAcc
- for i, j in T.Parallel(block_M, block_N):
- C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
- T.clear(C_local)
- # TMA store
- T.copy(C_local_accum, C_shared)
- T.copy(C_shared, C[by * block_M, bx * block_N])
-
- return fp8_gemm_kernel_
-
-
-def fp8_gemm(
- a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
-) -> torch.Tensor:
- """
- Perform a matrix multiplication using FP8 precision.
-
- Args:
- a (torch.Tensor): The first input matrix, must be contiguous.
- a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
- b (torch.Tensor): The second input matrix, must be contiguous.
- b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
-
- Returns:
- torch.Tensor: The result of the matrix multiplication.
- """
- assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
- assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling factor tensors must be contiguous"
- K = a.size(-1)
- M = a.numel() // K
- N = b.size(0)
- c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
- kernel = fp8_gemm_kernel(N, K)
- kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
- return c
-
-
-@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
-def fp8_index_kernel(h: int, d: int): # type: ignore
- b = T.symbolic("b")
- m = T.symbolic("m")
- n = T.symbolic("n")
-
- blk_n1 = 512
- blk_n2 = 128
-
- @T.prim_func
- def fp8_index_kernel_(
- q: T.Tensor[(b, m, h, d), FP8],
- q_s: T.Tensor[(b, m, h), FP32],
- k: T.Tensor[(b, n, d), FP8],
- k_s: T.Tensor[(b, n), FP32],
- o: T.Tensor[(b, m, n), FP32],
- ) -> None:
- with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
- q_smem = T.alloc_shared((h, d), FP8)
- T.copy(q[i_b, i_m, 0, 0], q_smem)
-
- q_s_frag = T.alloc_fragment(h, FP32)
- T.copy(q_s[i_b, i_m, 0], q_s_frag)
-
- for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
- k_smem = T.alloc_shared((blk_n2, d), FP8)
- T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
-
- k_s_frag = T.alloc_fragment(blk_n2, FP32)
- T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
-
- logits = T.alloc_fragment((blk_n2, h), FP32)
- T.gemm(
- k_smem,
- q_smem,
- logits,
- transpose_A=False,
- transpose_B=True,
- clear_accum=True,
- )
-
- for i_h, i3_n in T.Parallel(h, blk_n2):
- logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
-
- logits_sum = T.alloc_fragment(blk_n2, FP32)
- T.reduce_sum(logits, logits_sum, dim=1)
-
- for i3_n in T.Parallel(blk_n2):
- logits_sum[i3_n] *= k_s_frag[i3_n]
-
- T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
-
- return fp8_index_kernel_
-
-
-def fp8_index(
- q: torch.Tensor,
- q_s: torch.Tensor,
- k: torch.Tensor,
- k_s: torch.Tensor,
-) -> torch.Tensor:
- """
- Perform index score using FP8 precision.
-
- Args:
- q (torch.Tensor): The Q tensor, must be contiguous.
- q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
- k (torch.Tensor): The K tensor, must be contiguous.
- k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
-
- fp8 q @ fp8 k -> fp32 logits
- relu(fp32 logits) * q_s (weights) -> fp32 logits
- fp32 logits -> fp32 logits_sum
- fp32 logits_sum * k_s (e8m0) -> fp32 index_score
- """
- return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
diff --git a/python/models/glm_5/params.py b/python/models/glm_5/params.py
deleted file mode 100644
index 2721229..0000000
--- a/python/models/glm_5/params.py
+++ /dev/null
@@ -1 +0,0 @@
-"""GLM5 parameters and initializers."""
diff --git a/python/profiler/__init__.py b/python/profiler/__init__.py
deleted file mode 100644
index e9b1cf9..0000000
--- a/python/profiler/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Profiler utilities for TileRT."""
diff --git a/python/profiler/utils.py b/python/profiler/utils.py
deleted file mode 100644
index ecd83f6..0000000
--- a/python/profiler/utils.py
+++ /dev/null
@@ -1,477 +0,0 @@
-import os
-from dataclasses import dataclass
-from typing import Any
-
-import numpy as np
-import torch
-
-from tilert.utils import SLICES_FOR_TILERT_OP
-
-# Worker names used by ExecPlanDescriptor (previously from scheduling.plan_v0)
-WORKER_NAMES = [
- "Init",
- "Prefetch",
- "Compute",
- "ExtraTask1/SyncIo",
- "ExtraTask2/IoP0",
- "ExtraTask3/IoP2",
- "ExtraTask4",
- "ExtraTask5",
-]
-
-try:
- from openpyxl import Workbook
- from openpyxl.cell import Cell
- from openpyxl.styles import Alignment, Border, PatternFill, Side
- from openpyxl.styles.colors import COLOR_INDEX
- from openpyxl.worksheet.worksheet import Worksheet
-except ImportError:
- print("openpyxl is not installed, profile logs will not be visualized")
- Workbook = None
-
-
-__all__ = [
- "ExcelStyleConfigs",
- "ExecPlanDescriptor",
- "WorkerBookVisualizer",
- "visualize_profile_logs",
- "parse_profile_log_tensor",
- "parse_op_time",
-]
-
-
-@dataclass
-class ExcelStyleConfigs:
- """Excel style configurations."""
-
- # 2 col * 3 stream
- cols_per_worker: int = 6
- ns_per_tick: int = 1000
-
-
-@dataclass
-class ExecPlanDescriptor:
- """Exec plan descriptor."""
-
- workers_def: list
- op_lists: list
-
-
-class WorkerBookVisualizer:
- """Sheet visualizer."""
-
- def __init__(self, exec_plan_desc: ExecPlanDescriptor):
- self.exec_plan_desc = exec_plan_desc
-
- self.wb = Workbook()
- self.wb.remove(self.wb.active)
-
- # Excel configs
- self.style_configs = ExcelStyleConfigs()
-
- self.op_cols_splits = 3
-
- self.time_bar_cols = 1
- self.op_stat_bar_cols = 6
-
- workers_num = len(self.exec_plan_desc.workers_def)
- self.op_vis_bar_cols = workers_num * self.style_configs.cols_per_worker
- assert self.op_stat_bar_cols % self.op_cols_splits == 0
-
- @property
- def time_bar_next_col(self) -> int:
- return self.time_bar_cols + 1
-
- @property
- def op_stat_bar_next_col(self) -> int:
- return self.time_bar_next_col + self.op_stat_bar_cols
-
- @property
- def op_vis_bar_next_col(self) -> int:
- return self.op_stat_bar_next_col + self.op_vis_bar_cols
-
- @staticmethod
- def add_region_cell(
- ws: Worksheet,
- value: str,
- start_row: int,
- start_col: int,
- row_size: int = 1,
- col_size: int = 1,
- color_offset: int = -1,
- ) -> Cell:
- cell = ws.cell(row=start_row, column=start_col, value=value)
- cell.alignment = Alignment(horizontal="center", vertical="center", wrap_text=True)
- if color_offset >= 0:
- cell.fill = PatternFill(
- start_color=COLOR_INDEX[50 + color_offset],
- end_color=COLOR_INDEX[50 + color_offset],
- fill_type="solid",
- )
- ws.merge_cells(
- start_row=start_row,
- start_column=start_col,
- end_row=start_row + row_size - 1,
- end_column=start_col + col_size - 1,
- )
- return cell
-
- def init_layout(self, ws: Worksheet) -> None:
- workers_name = self.exec_plan_desc.workers_def
- worker_cols = self.style_configs.cols_per_worker
-
- self.add_region_cell(ws, "Op Info", 1, self.time_bar_next_col, 1, self.op_stat_bar_cols)
-
- for worker_id, worker_name in enumerate(workers_name):
- start_col = worker_cols * worker_id + self.op_stat_bar_next_col
- self.add_region_cell(ws, worker_name, 1, start_col, 1, worker_cols)
-
- def _parse_inst_info(
- self, insts_info: list[tuple[str, float, int] | tuple[str, float] | str], op_idx: int
- ) -> tuple[str, float, int]:
- inst_info = insts_info[op_idx]
- if isinstance(inst_info, str):
- op_name, op_cost = inst_info, 0.0
- op_stream = op_idx % self.op_cols_splits
- elif len(inst_info) == 2:
- op_name, op_cost = inst_info
- op_stream = op_idx % self.op_cols_splits
- elif len(inst_info) == 3:
- op_name, op_cost, op_stream = inst_info
- else:
- raise TypeError("Invalid inst_info format")
- return op_name, op_cost, op_stream
-
- def add_region_cell_by_time(
- self,
- ws: Worksheet,
- op_show_info: str,
- start_time: float,
- end_time: float,
- op_col_start: int,
- op_col_size: int,
- ns_tick: int,
- color_offset: int = -1,
- ) -> Cell:
- op_start_row_idx = np.round(start_time / ns_tick).astype(np.int32) + 2
- op_end_row_idx = np.round(end_time / ns_tick).astype(np.int32) + 2
- op_end_row_idx = max(op_end_row_idx, op_start_row_idx)
- return self.add_region_cell(
- ws,
- op_show_info,
- op_start_row_idx,
- op_col_start,
- max(op_end_row_idx - op_start_row_idx, 1),
- op_col_size,
- color_offset,
- )
-
- def timeline_visual_region(
- self,
- ws: Worksheet,
- profile_logs: np.ndarray,
- insts_info: list[tuple[str, float, int] | tuple[str, float] | str],
- ignore_prefilling: bool = True,
- ) -> None:
- ns_tick = self.style_configs.ns_per_tick
- self.init_layout(ws)
-
- total_end_time = 0
- for op_idx, op_log in enumerate(profile_logs):
- op_name, op_cost, op_stream = self._parse_inst_info(insts_info, op_idx)
-
- if op_stream >= self.op_cols_splits:
- print(f"stream_id (aka col_id) must < {self.op_cols_splits}")
- raise ValueError
-
- valid_mask: np.ndarray = op_log >= 0
- if ignore_prefilling:
- valid_mask[2:4] = False
-
- if np.count_nonzero(valid_mask) == 0:
- continue
-
- op_start_time = np.min(op_log, where=valid_mask, initial=np.inf)
- op_end_time = np.max(op_log, where=valid_mask, initial=-np.inf)
- total_end_time = max(total_end_time, op_end_time)
-
- op_cost_theory = op_cost / 1000
- op_cost_actual = (op_end_time - op_start_time) / 1000
- op_bw_utils = f"{op_cost_theory / op_cost_actual * 100:.2f}"
-
- op_show_info = (
- f"{op_name}\n"
- + f"BW Util: {op_bw_utils}%\n"
- + f"Actual: {op_cost_actual:.2f}us\n"
- + f"Theoretical: {op_cost_theory:.2f}us\n"
- + f"Start Time: {op_start_time / 1000:.2f}us\n"
- + f"End Time: {op_end_time / 1000:.2f}us"
- )
- op_col_size = self.op_stat_bar_cols // self.op_cols_splits
- op_col_start = self.time_bar_next_col + op_stream * op_col_size
- self.add_region_cell_by_time(
- ws,
- op_show_info,
- op_start_time,
- op_end_time,
- op_col_start,
- op_col_size,
- ns_tick,
- )
-
- for queue_idx, (start_time, end_time) in enumerate(zip(op_log[::2], op_log[1::2])):
- if start_time < 0 or end_time < 0:
- continue
- task_dur = (end_time - start_time) / 1000
- task_bw_utils = f"{min(100, op_cost_theory / task_dur * 100):.2f}"
- task_show_info = (
- f"{op_name}\n"
- + f"Dur: {task_dur:.2f}us\n"
- + f"BW Util. {task_bw_utils}%:\n"
- + f"Start: {start_time / 1000:.2f}us\n"
- + f"End: {end_time / 1000:.2f}us"
- )
- task_col_size = self.style_configs.cols_per_worker // self.op_cols_splits
- task_col_start = (
- self.op_stat_bar_next_col
- + queue_idx * self.style_configs.cols_per_worker
- + op_stream * task_col_size
- )
- cell = self.add_region_cell_by_time(
- ws,
- task_show_info,
- start_time,
- end_time,
- task_col_start,
- task_col_size,
- ns_tick,
- queue_idx,
- )
- cell.border = Border(
- left=Side(style="thin"),
- right=Side(style="thin"),
- top=Side(style="thin"),
- bottom=Side(style="thin"),
- )
-
- for dur_idx, dur_start in enumerate(range(0, int(total_end_time), ns_tick)):
- ws.cell(row=dur_idx + 2, column=1, value=f"{(dur_start + ns_tick) / 1000:.2f}")
-
- def brief_table_region(
- self,
- ws: Worksheet,
- profile_logs: np.ndarray,
- insts_info: list[tuple[str, float, int] | tuple[str, float] | str],
- ) -> None:
- for op_idx, op_log in enumerate(profile_logs):
- op_name, _, _ = self._parse_inst_info(insts_info, op_idx)
-
- ws.cell(row=op_idx + 2, column=self.op_vis_bar_next_col, value=op_name)
-
- for queue_idx, (start_time, end_time) in enumerate(zip(op_log[::2], op_log[1::2])):
- if start_time < 0 or end_time < 0:
- continue
- task_dur = (end_time - start_time) / 1000
- ws.cell(
- row=op_idx + 2, column=self.op_vis_bar_next_col + queue_idx + 1, value=task_dur
- )
-
- def add_sheet(self, profile_logs: np.ndarray, sheet_name: str) -> "WorkerBookVisualizer":
- """Add a sheet to the workbook."""
- wb = self.wb
- insts_info = self.exec_plan_desc.op_lists
-
- ws = wb.create_sheet(sheet_name)
- self.timeline_visual_region(ws, profile_logs, insts_info)
- self.brief_table_region(ws, profile_logs, insts_info)
-
- return self
-
- def add_sm_brief_sheet(
- self, profile_logs: np.ndarray, sheet_name: str
- ) -> "WorkerBookVisualizer":
- """Add a brief sheet to workbook which contains min/max start/end and duration among SMs"""
- wb = self.wb
- insts_info = self.exec_plan_desc.op_lists
- ws = wb.create_sheet(sheet_name)
-
- profile_logs = np.transpose(profile_logs, (1, 0, 2))
-
- # 1. init layout
- workers_name = self.exec_plan_desc.workers_def
- worker_metric_def = [
- "min_start",
- "max_end",
- "min_dur",
- "max_dur",
- "mean_dur",
- "std_dur",
- ]
-
- worker_cols = len(worker_metric_def)
-
- self.add_region_cell(ws, "Op Info", 1, self.time_bar_next_col, 1, self.op_stat_bar_cols)
-
- for worker_id, worker_name in enumerate(workers_name):
- start_col = worker_cols * worker_id + self.op_stat_bar_next_col
- self.add_region_cell(ws, worker_name, 1, start_col, 1, worker_cols)
- for metric_id, metric_name in enumerate(worker_metric_def):
- start_col_metric = start_col + metric_id
- self.add_region_cell(ws, metric_name, 2, start_col_metric, 1, 1)
-
- # 2. calc metrics
- # profile_logs: (num_ops, num_sm, num_task*2)
- for op_idx, op_profile_log in enumerate(profile_logs):
- valid_mask = (op_profile_log >= 0) & (op_profile_log < 1e9)
- # skip if this op is fully invalid
- if not np.any(valid_mask):
- continue
-
- op_name, _, _ = self._parse_inst_info(insts_info, op_idx)
- self.add_region_cell(ws, op_name, op_idx + 3, self.time_bar_next_col, 1, 2)
-
- for queue_idx in range(op_profile_log.shape[1] // 2):
- starts = op_profile_log[:, queue_idx * 2]
- ends = op_profile_log[:, queue_idx * 2 + 1]
-
- valid_mask = (
- (starts >= 0) & (starts < 1e9) & (ends >= 0) & (ends < 1e9) & (starts <= ends)
- )
-
- valid_starts = starts[valid_mask] / 1000
- valid_ends = ends[valid_mask] / 1000
-
- if len(valid_starts) == 0:
- continue
-
- min_start = np.min(valid_starts)
- max_end = np.max(valid_ends)
- durations = valid_ends - valid_starts
-
- metrics_values = [
- min_start,
- max_end,
- np.min(durations),
- np.max(durations),
- np.mean(durations),
- np.std(durations),
- ]
-
- # row_idx start from 3, because {1: work_name, 2: metric_name}
- # col_idx start from worker::start_col
- start_row = op_idx + 3
- start_col = worker_cols * queue_idx + self.op_stat_bar_next_col
- color_offset = queue_idx
-
- for i, value in enumerate(metrics_values):
- # color mean and std dev
- cell_color = color_offset if i >= 4 else -1
- self.add_region_cell(ws, value, start_row, start_col + i, 1, 1, cell_color)
-
- return self
-
- def save(self, out_path: str) -> None:
- """Save the workbook to a file."""
- os.makedirs(os.path.dirname(out_path), exist_ok=True)
- self.wb.save(out_path)
-
-
-def visualize_profile_logs(
- all_profile_logs: np.ndarray,
- out_path: str,
- inst2opname: list[tuple[str, float, int] | tuple[str, float] | str],
- with_mean: bool = False,
- with_max: bool = False,
-) -> None:
- """Visualize profile logs."""
- valid_ctas = np.argwhere(np.any(all_profile_logs != 0, axis=(1, 2)))[:, 0]
- filtered_logs = all_profile_logs[valid_ctas]
- filtered_masks = np.logical_and(filtered_logs >= 0, filtered_logs < 1e9)
- mean_profile_logs = np.mean(filtered_logs, axis=0, where=filtered_masks)
- mean_profile_logs[np.isnan(mean_profile_logs)] = -1
- if filtered_logs.size == 0:
- return
- assemble_profile_logs = np.zeros_like(filtered_logs[0])
- assemble_profile_logs[:, ::2] = np.min(
- filtered_logs[..., ::2], axis=0, where=filtered_masks[..., ::2], initial=np.inf
- )
- assemble_profile_logs[:, 1::2] = np.max(
- filtered_logs[..., 1::2], axis=0, where=filtered_masks[..., 1::2], initial=-np.inf
- )
- assemble_profile_logs[np.isinf(assemble_profile_logs)] = -1
-
- visualizer = WorkerBookVisualizer(ExecPlanDescriptor(WORKER_NAMES, inst2opname))
- if with_mean:
- visualizer.add_sheet(mean_profile_logs, "mean")
- if with_max:
- raise NotImplementedError("with_max is not implemented")
-
- visualizer.add_sm_brief_sheet(filtered_logs, "mean_sm_brief")
- for block_idx, profile_logs in enumerate(filtered_logs):
- profile_logs[profile_logs > 1e9] = -1
- visualizer.add_sheet(profile_logs, f"block_{block_idx}")
- visualizer.save(out_path)
-
-
-def parse_profile_log_tensor(
- profile_logs_tensor: torch.Tensor,
- out_path: str,
- inst2opname: Any,
- with_mean: bool = False,
-) -> None:
- """Parse a profile log tensor into a dictionary.
-
- Args:
- profile_log_tensor: The profile log tensor.
- out_path: The path to save the profile logs.
- inst2opname: The mapping from instance index to operation name.
-
- list[tuple[str, float, int] | tuple[str, float] | str]
-
- Returns:
- None.
- """
- # Remove the extra slices for storing instructions and glb bars.
- profile_logs_tensor = profile_logs_tensor[:-SLICES_FOR_TILERT_OP, :, :]
-
- profile_logs = profile_logs_tensor.cpu().detach().numpy()
- valid_insts_logs = np.any(profile_logs != 0, axis=(1, 2))
- profile_logs = profile_logs[valid_insts_logs]
- valid_blocks_logs = np.any(profile_logs != 0, axis=(0, 2))
- profile_logs = profile_logs[:, valid_blocks_logs, :]
- # Return if no valid blocks logs are found.
- if profile_logs.size == 0:
- print("Warning: No profile logs available.")
- return
- profile_logs = np.transpose(profile_logs, (1, 0, 2))
- ctx_start_times = profile_logs[:, 0, 0]
- profile_logs = profile_logs[:, 1:, :]
- profile_logs = (profile_logs - ctx_start_times[:, None, None]).astype(np.float32) / 1.855
-
- if Workbook is not None:
- visualize_profile_logs(profile_logs, out_path, inst2opname, with_mean)
-
-
-def parse_op_time(profile_logs: torch.Tensor, op_idx: int = 0, block_idx: int = 0) -> None:
- data = profile_logs[op_idx, block_idx, :].cpu().numpy()
- max_time = data.max()
- start_time = data.min()
- FREQUENCY = 1850.0
-
- worker_names = [
- "controller",
- " sync_io",
- " io_p0",
- " io_p1",
- " io_p2",
- " consumer",
- " extra1",
- " extra2",
- ]
- for i, worker_name in enumerate(worker_names):
- if data[i * 2] != max_time:
- print(
- f"{worker_name}:\tstart:{(data[i * 2] - start_time) / FREQUENCY:.3f}, "
- f"duration:{(data[i * 2 + 1] - data[i * 2]) / FREQUENCY:.3f}, "
- f"end:{(data[i * 2 + 1] - start_time) / FREQUENCY:.3f}"
- )
diff --git a/requirements.txt b/requirements.txt
index ae9da40..fd4a9ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,15 @@
-torch>=2.6.0
+# Runtime dependencies for the v0.1.4 wheel, pinned to the exact ABI the
+# wheel was built against. ``torch`` must be installed from PyTorch's cu130
+# index — PyPI's default ``torch`` is a different CUDA build and will not load
+# the cu130-linked tilert binary:
+#
+# pip install --index-url https://download.pytorch.org/whl/cu130 torch==2.11.0
+# pip install -r requirements.txt
+#
+# The recommended path remains the prebuilt Docker image (see README).
+torch==2.11.0
+transformers==4.46.3
+tokenizers==0.20.3
numpy
-transformers
+scipy
+einops
diff --git a/tilert/__init__.py b/tilert/__init__.py
new file mode 100644
index 0000000..d34ce51
--- /dev/null
+++ b/tilert/__init__.py
@@ -0,0 +1,91 @@
+"""TileRT Python package.
+
+Two backend libraries ship with TileRT — one per model family:
+
+ - ``libtilert_dsv32.so`` (DeepSeek-V3.2)
+ - ``libtilert_glm5.so`` (GLM-5)
+
+They are NOT loaded at import time. The caller selects a backend via
+``load_backend(model_type)`` (done automatically by ``tilert.generate``).
+Only one backend may be loaded per process — both register the ``tilert``
+torch-op namespace. Run DSv3.2 and GLM-5 in separate processes.
+"""
+
+import ctypes
+import logging
+import os
+from importlib.metadata import PackageNotFoundError
+from importlib.metadata import version as pkg_version
+from pathlib import Path
+
+import torch
+
+if not hasattr(torch, "ops"):
+ raise RuntimeError("PyTorch is required but torch.ops is not available")
+
+try:
+ __version__ = pkg_version("tilert")
+except PackageNotFoundError:
+ __version__ = "0.0.0"
+
+
+def init_logging() -> logging.Logger:
+ """Initialize logging configuration."""
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format="%(filename)s:%(lineno)d [%(levelname)s]: %(message)s",
+ )
+ return logging.getLogger(__name__)
+
+
+logger = init_logging()
+
+_BACKENDS = {
+ "deepseek_v3_2": "libtilert_dsv32.so",
+ "glm5": "libtilert_glm5.so",
+}
+
+_loaded_backend: str | None = None
+
+
+def load_backend(model_type: str) -> None:
+ """Load the backend for ``model_type`` (lazy, once per process).
+
+ DeepSeek-V3.2 and GLM-5 ship as separate libraries; the matching one is
+ loaded on first use. Loading a second, different backend in the same
+ process raises (both libraries define the ``tilert`` op namespace).
+ """
+ global _loaded_backend
+ so_name = _BACKENDS.get(model_type)
+ if so_name is None:
+ raise ValueError(f"Unknown model_type {model_type!r}. Supported: {sorted(_BACKENDS)}")
+ if _loaded_backend is not None:
+ if _loaded_backend != so_name:
+ raise RuntimeError(
+ f"TileRT backend '{_loaded_backend}' already loaded; cannot load "
+ f"'{so_name}' in the same process. Run {model_type} in a fresh process."
+ )
+ return
+ pkg_dir = Path(__file__).parent
+ lib_path = pkg_dir / so_name
+ if not lib_path.exists():
+ fallback = pkg_dir / "libtilert.so"
+ if not fallback.exists():
+ raise RuntimeError(f"Backend library not found: {lib_path}.")
+ lib_path = fallback
+ ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL | os.RTLD_LAZY)
+ torch.ops.load_library(str(lib_path))
+ _loaded_backend = so_name
+ logger.info(
+ "Loaded TileRT backend %s (%s) for model_type=%s", so_name, lib_path.name, model_type
+ )
+
+
+from .tilert_init import tilert_init # noqa: E402
+
+__all__ = [
+ "logger",
+ "load_backend",
+ "tilert_init",
+ "__version__",
+]
diff --git a/python/benchmark/__init__.py b/tilert/benchmark/__init__.py
similarity index 83%
rename from python/benchmark/__init__.py
rename to tilert/benchmark/__init__.py
index 49a349d..1194444 100644
--- a/python/benchmark/__init__.py
+++ b/tilert/benchmark/__init__.py
@@ -4,9 +4,8 @@
from typing import TypeAlias
from tilert.models.deepseek_v3_2.generator import DSAv32Generator
-from tilert.models.glm_5.generator import GLM5Generator
-Generator: TypeAlias = DSAv32Generator | GLM5Generator
+Generator: TypeAlias = DSAv32Generator
@dataclass
@@ -15,7 +14,6 @@ class BenchMode:
with_mtp: bool
label: str
- # Sampling parameters — None means keep current generator defaults (top-k1 argmax).
use_topp: bool = False
top_p: float = 1.0
top_k: int = 256
@@ -27,13 +25,25 @@ class CellStats:
"""Stats for a single table cell (one mode x one benchmark column)."""
tok_s: float = 0.0
- ms: float = 0.0
+ iters_s: str = "-"
acc_rate: str = "-"
BenchStats = dict[str, dict[str, CellStats]]
+@dataclass
+class PerStepData:
+ """Per-step timing data from a single generation run."""
+
+ prompt_len: int
+ time_list: list[float]
+ accepted_counts: list[int]
+
+
+PerStepDict = dict[str, dict[str, list[PerStepData]]]
+
+
def apply_mode(generator: Generator, mode: BenchMode) -> None:
"""Apply sampling parameters for a benchmark mode."""
generator.update_sampling_params(
@@ -68,16 +78,14 @@ def print_summary_table(
if not all_stats:
return
- # Collect column keys in insertion order (preserves benchmark ordering)
col_keys: list[str] = []
for cols in all_stats.values():
for k in cols:
if k not in col_keys:
col_keys.append(k)
- ROW_LABELS = ["tok/s", "ms", "acc"]
+ ROW_LABELS = ["tok/s", "it/s", "acc"]
- # Build formatted cell strings: {mode: {col: [row0, row1, row2]}}
formatted: dict[str, dict[str, list[str]]] = {}
for mode, cols in all_stats.items():
formatted[mode] = {}
@@ -88,11 +96,10 @@ def print_summary_table(
else:
formatted[mode][k] = [
_fmt(cell.tok_s, "tok/s"),
- _fmt(cell.ms, "ms"),
+ cell.iters_s,
cell.acc_rate,
]
- # Compute column widths
col_widths: dict[str, int] = {}
for k in col_keys:
w = len(k)
@@ -102,22 +109,18 @@ def print_summary_table(
col_widths[k] = w
mode_width = max(len("Mode"), max(len(m) for m in all_stats))
- # Row label column shares the mode column; pick wider of mode names vs row labels
mode_width = max(mode_width, max(len(r) for r in ROW_LABELS))
print(f"\n## Benchmark Summary ({model_name})\n")
- # Header
hdr = [f" {'Mode':<{mode_width}} "]
hdr += [f" {k:<{col_widths[k]}} " for k in col_keys]
print("|" + "|".join(hdr) + "|")
- # Separator
sep = ["-" * (mode_width + 2)]
sep += ["-" * (col_widths[k] + 2) for k in col_keys]
print("|" + "|".join(sep) + "|")
- # Data rows: 3 rows per mode
mode_list = list(all_stats.keys())
for _, mode in enumerate(mode_list):
for row_idx, _row_label in enumerate(ROW_LABELS):
diff --git a/python/benchmark/coding_prompt.py b/tilert/benchmark/coding_prompt.py
similarity index 54%
rename from python/benchmark/coding_prompt.py
rename to tilert/benchmark/coding_prompt.py
index e4ff6ed..1d98b34 100644
--- a/python/benchmark/coding_prompt.py
+++ b/tilert/benchmark/coding_prompt.py
@@ -3,17 +3,27 @@
from typing import cast
import numpy as np
-from benchmark import BenchMode, BenchStats, CellStats, Generator, apply_mode
+
+from tilert.benchmark import (
+ BenchMode,
+ BenchStats,
+ CellStats,
+ Generator,
+ PerStepData,
+ PerStepDict,
+ apply_mode,
+)
PROMPT = "Hi, can you write a sort program in C for me?"
-def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
+def run(generator: Generator, modes: list[BenchMode]) -> tuple[BenchStats, PerStepDict]:
"""Run the coding-prompt benchmark for each mode.
Returns stats with column: Coding.
"""
stats: BenchStats = {}
+ per_step: PerStepDict = {}
for mode in modes:
apply_mode(generator, mode)
@@ -21,8 +31,8 @@ def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
print(f"Prompt: {PROMPT}")
print("Completion:")
- _, time_list, accepted_counts = cast(
- tuple[str, list[float], list[int]],
+ _, time_list, accepted_counts, prompt_len = cast(
+ tuple[str, list[float], list[int], int],
generator.generate(PROMPT, True, with_mtp=mode.with_mtp),
)
@@ -32,15 +42,25 @@ def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
total_tokens = sum(accepted_counts)
total_time = sum(time_list)
speed = total_tokens / total_time if total_time > 0 else 0
- avg_ms = total_time / len(time_list) * 1000
avg_a = total_tokens / len(accepted_counts)
acc_rate = f"{avg_a:.2f}/{min(accepted_counts)}/{max(accepted_counts)}"
- mode_stats["Coding"] = CellStats(tok_s=speed, ms=avg_ms, acc_rate=acc_rate)
+ iters_s = len(time_list) / total_time if total_time > 0 else 0.0
+ mode_stats["Coding"] = CellStats(
+ tok_s=speed, iters_s=f"{iters_s:.1f} it/s", acc_rate=acc_rate
+ )
elif time_list:
mean_time = float(np.mean(time_list))
speed = 1 / mean_time
- mode_stats["Coding"] = CellStats(tok_s=speed, ms=mean_time * 1000)
+ mode_stats["Coding"] = CellStats(tok_s=speed, iters_s=f"{speed:.1f} it/s")
+
+ per_step[mode.label] = {
+ "Coding": [
+ PerStepData(
+ prompt_len=prompt_len, time_list=time_list, accepted_counts=accepted_counts
+ )
+ ]
+ }
stats[mode.label] = mode_stats
- return stats
+ return stats, per_step
diff --git a/tilert/benchmark/config.py b/tilert/benchmark/config.py
new file mode 100644
index 0000000..5f40628
--- /dev/null
+++ b/tilert/benchmark/config.py
@@ -0,0 +1,69 @@
+"""TileRT configuration file loading.
+
+Reads model weights paths from ~/.tilert/config.toml so that benchmark scripts
+and regression workflows do not need hardcoded paths.
+
+Config file format (~/.tilert/config.toml):
+
+ [weights]
+ deepseek_v3_2 = "/path/to/tilert_weights/DeepSeek-V32"
+ deepseek_v3_2_v2 = "/path/to/tilert_weights/DeepSeek-V32-v2"
+"""
+
+import tomllib
+from pathlib import Path
+
+CONFIG_DIR = Path.home() / ".tilert"
+CONFIG_FILE = CONFIG_DIR / "config.toml"
+
+
+def get_config_path() -> Path:
+ """Return the path to the TileRT config file."""
+ return CONFIG_FILE
+
+
+def get_weights_dir(model: str, cli_override: str | None = None) -> str:
+ """Resolve the weights directory for *model*.
+
+ Resolution order (highest priority first):
+ 1. *cli_override* (from ``--model-weights-dir`` CLI flag)
+ 2. ``~/.tilert/config.toml`` → ``[weights].``
+
+ Raises ``FileNotFoundError`` / ``KeyError`` with a user-friendly message
+ when the config file or key is missing.
+ """
+ if cli_override is not None:
+ return cli_override
+
+ config_path = get_config_path()
+ if not config_path.exists():
+ raise FileNotFoundError(
+ f"No --model-weights-dir provided and config file not found at {config_path}.\n"
+ f"Create it with:\n\n"
+ f" mkdir -p {CONFIG_DIR}\n"
+ f" cat > {config_path} << 'EOF'\n"
+ f" [weights]\n"
+ f' deepseek_v3_2 = "/path/to/DeepSeek-V32"\n'
+ f" EOF\n"
+ )
+
+ try:
+ with open(config_path, "rb") as f:
+ config = tomllib.load(f)
+ except tomllib.TOMLDecodeError as e:
+ raise ValueError(
+ f"Failed to parse {config_path}: {e}\n" f"Please check the file for syntax errors."
+ ) from e
+
+ weights = config.get("weights", {})
+ if model not in weights:
+ available = ", ".join(weights.keys()) if weights else "(none)"
+ raise KeyError(
+ f"Model {model!r} not found in {config_path} [weights] section.\n"
+ f"Available models: {available}\n"
+ f"Add it with:\n\n"
+ f" [weights]\n"
+ f' {model} = "/path/to/{model}/weights"\n'
+ )
+
+ return str(weights[model])
diff --git a/tilert/benchmark/long_prompt.py b/tilert/benchmark/long_prompt.py
new file mode 100644
index 0000000..7df175b
--- /dev/null
+++ b/tilert/benchmark/long_prompt.py
@@ -0,0 +1,82 @@
+"""Long-prompt benchmark: single generation, measures long-form throughput."""
+
+from typing import cast
+
+import numpy as np
+
+from tilert.benchmark import (
+ BenchMode,
+ BenchStats,
+ CellStats,
+ Generator,
+ PerStepData,
+ PerStepDict,
+ apply_mode,
+)
+
+PROMPT = "Hi, can you tell me a very long story, with roughly 3000 words?"
+
+
+def run(generator: Generator, modes: list[BenchMode]) -> tuple[BenchStats, PerStepDict]:
+ """Run the long-prompt benchmark for each mode.
+
+ Returns stats with column: Long.
+ """
+ stats: BenchStats = {}
+ per_step: PerStepDict = {}
+
+ for mode in modes:
+ apply_mode(generator, mode)
+ print(f"\n--- Long-prompt benchmark ({mode.label}) ---")
+ print(f"Prompt: {PROMPT}")
+ print("Completion:")
+
+ _, time_list, accepted_counts, prompt_len = cast(
+ tuple[str, list[float], list[int], int],
+ generator.generate(PROMPT, True, with_mtp=mode.with_mtp),
+ )
+
+ mode_stats: dict[str, CellStats] = {}
+
+ if mode.with_mtp and accepted_counts:
+ total_tokens = sum(accepted_counts)
+ total_time = sum(time_list)
+ speed = total_tokens / total_time if total_time > 0 else 0
+ avg_a = total_tokens / len(accepted_counts)
+ acc_rate = f"{avg_a:.2f}/{min(accepted_counts)}/{max(accepted_counts)}"
+
+ cumtok = list(np.cumsum(accepted_counts))
+ split_idx = next((i for i, t in enumerate(cumtok) if t >= 2048), len(time_list))
+ end_idx = next((i for i, t in enumerate(cumtok) if t >= 2048 + 512), len(time_list))
+ pre_time = time_list[:split_idx]
+ post_time = time_list[split_idx:end_idx]
+ pre_ips = len(pre_time) / sum(pre_time) if pre_time else 0.0
+ post_ips = len(post_time) / sum(post_time) if post_time else 0.0
+ iters_s = f"{pre_ips:.1f}/{post_ips:.1f} it/s"
+
+ mode_stats["Long"] = CellStats(tok_s=speed, iters_s=iters_s, acc_rate=acc_rate)
+ elif time_list:
+ mean_time = float(np.mean(time_list))
+ speed = 1 / mean_time
+
+ split_idx = min(2048, len(time_list))
+ end_idx = min(2048 + 512, len(time_list))
+ pre_time = time_list[:split_idx]
+ post_time = time_list[split_idx:end_idx]
+ pre_ips = len(pre_time) / sum(pre_time) if pre_time else 0.0
+ post_ips = len(post_time) / sum(post_time) if post_time else 0.0
+ iters_s = f"{pre_ips:.1f}/{post_ips:.1f} it/s"
+
+ mode_stats["Long"] = CellStats(tok_s=speed, iters_s=iters_s)
+
+ per_step[mode.label] = {
+ "Long": [
+ PerStepData(
+ prompt_len=prompt_len, time_list=time_list, accepted_counts=accepted_counts
+ )
+ ]
+ }
+
+ stats[mode.label] = mode_stats
+
+ return stats, per_step
diff --git a/python/benchmark/short_prompt.py b/tilert/benchmark/short_prompt.py
similarity index 66%
rename from python/benchmark/short_prompt.py
rename to tilert/benchmark/short_prompt.py
index bebd2ce..4bdebe2 100644
--- a/python/benchmark/short_prompt.py
+++ b/tilert/benchmark/short_prompt.py
@@ -1,41 +1,59 @@
-"""Short-prompt benchmark: 20 iterations, measures steady-state decode throughput."""
+"""Short-prompt benchmark: 1 warmup + 20 iterations, measures steady-state decode throughput."""
from typing import cast
import numpy as np
-from benchmark import BenchMode, BenchStats, CellStats, Generator, apply_mode
+
+from tilert.benchmark import (
+ BenchMode,
+ BenchStats,
+ CellStats,
+ Generator,
+ PerStepData,
+ PerStepDict,
+ apply_mode,
+)
PROMPT = "Tell me 10 jokes, keep them all under 100 words."
NUM_ITERS = 20
TOKEN_CHECKPOINTS = [200]
-def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
+def run(generator: Generator, modes: list[BenchMode]) -> tuple[BenchStats, PerStepDict]:
"""Run the short-prompt benchmark for each mode.
Returns stats with columns: Short@ for each checkpoint.
"""
stats: BenchStats = {}
+ per_step: PerStepDict = {}
for mode in modes:
apply_mode(generator, mode)
print(f"\n--- Short-prompt benchmark ({mode.label}) ---", flush=True)
+ print(" warmup...", flush=True)
+ generator.generate(PROMPT, False, with_mtp=mode.with_mtp)
+
all_times: list[list[float]] = []
all_accepted: list[list[int]] = []
all_results: list[str] = []
+ all_per_step_data: list[PerStepData] = []
for _iter in range(NUM_ITERS):
if _iter % 5 == 0:
print(f" iter {_iter}/{NUM_ITERS}...", flush=True)
- result, time_list, accepted_counts = cast(
- tuple[str, list[float], list[int]],
+ result, time_list, accepted_counts, prompt_len = cast(
+ tuple[str, list[float], list[int], int],
generator.generate(PROMPT, False, with_mtp=mode.with_mtp),
)
all_times.append(time_list)
all_accepted.append(accepted_counts)
all_results.append(result)
+ all_per_step_data.append(
+ PerStepData(
+ prompt_len=prompt_len, time_list=time_list, accepted_counts=accepted_counts
+ )
+ )
- # Verify determinism and print output once
mismatches = [i for i, r in enumerate(all_results) if r != all_results[0]]
if mismatches:
print(f" WARNING: non-deterministic output at iters {mismatches}")
@@ -47,21 +65,21 @@ def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
if mode.with_mtp:
for token_num in TOKEN_CHECKPOINTS:
speeds: list[float] = []
+ iter_rates: list[float] = []
for time_list, accepted_list in zip(all_times, all_accepted):
if time_list and accepted_list:
cumsum_tokens = np.cumsum(accepted_list)
cumsum_times = np.cumsum(time_list)
idx = int(np.searchsorted(cumsum_tokens, token_num))
- # If total tokens < token_num, use all available data
if idx >= len(cumsum_times):
idx = len(cumsum_times) - 1
tok_count = int(cumsum_tokens[idx])
elapsed = float(cumsum_times[idx])
if elapsed > 0:
speeds.append(tok_count / elapsed)
+ iter_rates.append((idx + 1) / elapsed)
if speeds:
speed = float(np.mean(speeds))
- mean_time = 1 / speed
flat_accepted = [a for al in all_accepted for a in al]
acc_rate = "-"
@@ -69,8 +87,11 @@ def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
avg_a = sum(flat_accepted) / len(flat_accepted)
acc_rate = f"{avg_a:.2f}/{min(flat_accepted)}/{max(flat_accepted)}"
+ iters_s = float(np.mean(iter_rates)) if iter_rates else 0.0
mode_stats[f"Short@{token_num}"] = CellStats(
- tok_s=speed, ms=mean_time * 1000, acc_rate=acc_rate
+ tok_s=speed,
+ iters_s=f"{iters_s:.1f} it/s",
+ acc_rate=acc_rate,
)
else:
for token_num in TOKEN_CHECKPOINTS:
@@ -82,8 +103,15 @@ def run(generator: Generator, modes: list[BenchMode]) -> BenchStats:
if per_token_times:
mean_time = float(np.mean(per_token_times))
speed = 1 / mean_time
- mode_stats[f"Short@{token_num}"] = CellStats(tok_s=speed, ms=mean_time * 1000)
+ mode_stats[f"Short@{token_num}"] = CellStats(
+ tok_s=speed, iters_s=f"{speed:.1f} it/s"
+ )
+
+ mode_per_step: dict[str, list[PerStepData]] = {}
+ for token_num in TOKEN_CHECKPOINTS:
+ mode_per_step[f"Short@{token_num}"] = all_per_step_data
+ per_step[mode.label] = mode_per_step
stats[mode.label] = mode_stats
- return stats
+ return stats, per_step
diff --git a/tilert/generate.py b/tilert/generate.py
new file mode 100644
index 0000000..bfcd97f
--- /dev/null
+++ b/tilert/generate.py
@@ -0,0 +1,299 @@
+"""Text generation script for TileRT."""
+
+import time
+from argparse import ArgumentParser
+from typing import TYPE_CHECKING
+
+import tilert
+
+if TYPE_CHECKING:
+ from tilert.models.deepseek_v3_2.generator import DSAv32Generator
+ from tilert.models.glm_5.generator import GLM5Generator
+from tilert.benchmark import BenchMode
+from tilert.benchmark import coding_prompt as coding_bench
+from tilert.benchmark import long_prompt as long_bench
+from tilert.benchmark import merge_stats, print_summary_table
+from tilert.benchmark import short_prompt as short_bench
+from tilert.benchmark.config import get_weights_dir
+
+
+def get_generator(
+ model_type: str,
+ max_new_tokens: int,
+ temperature: float,
+ model_weights_dir: str,
+ with_mtp: bool,
+ top_p: float = 0.9,
+ top_k: int = 256,
+ enable_thinking: bool = False,
+ sampling_seed: int = 42,
+) -> "DSAv32Generator | GLM5Generator":
+ """Load the matching backend .so and build the generator for ``model_type``.
+
+ DeepSeek-V3.2 and GLM-5 ship as separate libraries; only one backend loads
+ per process. Generators are imported lazily after the backend is loaded.
+ """
+ tilert.load_backend(model_type)
+
+ if model_type == "deepseek_v3_2":
+ from tilert.models.deepseek_v3_2.generator import DSAv32Generator
+ from tilert.models.deepseek_v3_2.model_args import ModelArgs as DSAv32ModelArgs
+
+ return DSAv32Generator(
+ model_args=DSAv32ModelArgs(),
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ model_weights_dir=model_weights_dir,
+ with_mtp=with_mtp,
+ top_p=top_p,
+ top_k=top_k,
+ use_topp=top_p < 1.0,
+ sampling_seed=sampling_seed,
+ enable_thinking=enable_thinking,
+ )
+
+ if model_type == "glm5":
+ from tilert.models.glm_5.generator import GLM5Generator
+ from tilert.models.glm_5.model_args import ModelArgsGLM5
+
+ return GLM5Generator(
+ model_args=ModelArgsGLM5(),
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ model_weights_dir=model_weights_dir,
+ with_mtp=with_mtp,
+ top_p=top_p,
+ top_k=top_k,
+ use_topp=top_p < 1.0,
+ enable_thinking=enable_thinking,
+ sampling_seed=sampling_seed,
+ )
+
+ raise ValueError(f"unsupported model_type: {model_type!r}")
+
+
+def parse_args(): # type: ignore
+ parser = ArgumentParser(description="Command-line interface for text generation.")
+ parser.add_argument(
+ "--model-weights-dir",
+ type=str,
+ default=None,
+ help="Path to model weights directory (resolved from ~/.tilert/config.toml if omitted)",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="deepseek_v3_2",
+ choices=["deepseek_v3_2", "glm5"],
+ help="Model type to use (default: deepseek_v3_2).",
+ )
+ parser.add_argument("--max-new-tokens", type=int, default=4000, help="Max tokens to generate")
+ parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
+ parser.add_argument(
+ "--top-p",
+ type=float,
+ default=1.0,
+ help="Top-p (nucleus) sampling threshold. Use < 1.0 to enable top-p sampling (e.g. 0.9)",
+ )
+ parser.add_argument("--top-k", type=int, default=256, help="Top-k sampling threshold")
+ parser.add_argument("--interactive", action="store_true")
+ parser.add_argument(
+ "--with-mtp",
+ action="store_true",
+ help="Enable MTP (Multi-Token Prediction) for speculative decoding",
+ )
+ parser.add_argument(
+ "--use-random-weights",
+ action="store_true",
+ help="Use random weights instead of pretrained (for testing MTP without real weights)",
+ )
+ parser.add_argument(
+ "--enable-thinking",
+ action="store_true",
+ help="Enable thinking mode in chat template",
+ )
+ parser.add_argument(
+ "--sampling-seed",
+ type=int,
+ default=42,
+ help="Sampling seed for top-p sampling (fixed per request, default: 42)",
+ )
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default=None,
+ help="Override display name for benchmark tables",
+ )
+ parser.add_argument(
+ "--tag",
+ type=str,
+ default=None,
+ help="Tag for regression_plots/ directory (default: auto-detect from git state)",
+ )
+ parser.add_argument(
+ "--modes",
+ type=str,
+ default=None,
+ help="Comma-separated mode filters: top-k1,top-p0.95 (default: all)",
+ )
+ parser.add_argument(
+ "--workloads",
+ type=str,
+ default=None,
+ help="Comma-separated workload filters: short,coding,long (default: all)",
+ )
+ parser.add_argument(
+ "--enable-logprobs",
+ action="store_true",
+ help="Enable kernel-level top-256 logprobs export (for benchmarking overhead)",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ """
+ Usage (run as a module; --model-weights-dir may be omitted if the path is
+ registered under ~/.tilert/config.toml). Run DeepSeek-V3.2 and GLM-5 in
+ separate processes — the two backends cannot coexist in one interpreter.
+
+ # DeepSeek-V3.2 — standard generation with pretrained weights:
+ python -m tilert.generate --model deepseek_v3_2 \
+ --model-weights-dir /path/to/DeepSeek-V3.2-TileRT \
+ --max-new-tokens 1000 2>&1 | tee test.log
+
+ # DeepSeek-V3.2 — MTP generation with random weights (for testing):
+ python -m tilert.generate --model deepseek_v3_2 --with-mtp --use-random-weights \
+ --model-weights-dir /path/to/DeepSeek-V3.2-TileRT \
+ --max-new-tokens 1000 2>&1 | tee test.log
+
+ # DeepSeek-V3.2 — MTP generation with pretrained weights:
+ python -m tilert.generate --model deepseek_v3_2 --with-mtp \
+ --model-weights-dir /path/to/DeepSeek-V3.2-TileRT \
+ --max-new-tokens 1000 2>&1 | tee test.log
+
+ # GLM-5 — standard generation:
+ python -m tilert.generate --model glm5 \
+ --model-weights-dir /path/to/GLM-5-FP8-TileRT \
+ --max-new-tokens 1000 2>&1 | tee test.log
+
+ # GLM-5 — MTP generation:
+ python -m tilert.generate --model glm5 --with-mtp \
+ --model-weights-dir /path/to/GLM-5-FP8-TileRT \
+ --max-new-tokens 1000 2>&1 | tee test.log
+ """
+ args = parse_args()
+
+ config_key = args.model
+ model_name = args.model.upper()
+ if args.model_name:
+ model_name = args.model_name
+ model_weights_dir = get_weights_dir(config_key, cli_override=args.model_weights_dir)
+
+ if args.interactive:
+ with_mtp = args.with_mtp
+ else:
+ with_mtp = True
+
+ generator = get_generator(
+ model_type=args.model,
+ max_new_tokens=args.max_new_tokens,
+ temperature=args.temperature,
+ model_weights_dir=model_weights_dir,
+ with_mtp=with_mtp,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ enable_thinking=args.enable_thinking,
+ sampling_seed=args.sampling_seed,
+ )
+
+ t0 = time.monotonic()
+ if args.use_random_weights:
+ print("Initializing random weights...")
+ if hasattr(generator, "init"):
+ generator.init() # type: ignore[union-attr]
+ generator.init_random_weights()
+ else:
+ print("Loading pretrained weights...")
+ generator.from_pretrained()
+ load_time = time.monotonic() - t0
+
+ if args.enable_logprobs:
+ if hasattr(generator.decode_layer, "set_logprobs_enabled"):
+ generator.decode_layer.set_logprobs_enabled(True) # type: ignore[union-attr]
+ print("Logprobs export enabled (top-256)")
+ else:
+ print(f"Warning: logprobs not supported for {type(generator).__name__}")
+
+ if args.interactive:
+ print("Welcome to the TileRT interactive mode! Type '/exit' to exit.")
+ while True:
+ prompt = input(">>> ")
+ if prompt == "/exit":
+ break
+ _ = generator.generate(prompt) # type: ignore[has-type]
+ else:
+
+ bench_top_p = args.top_p if args.top_p < 1.0 else 0.95
+ modes = [
+ BenchMode(with_mtp=False, label="top-k1 w/o MTP"),
+ BenchMode(with_mtp=True, label="top-k1 w/ MTP"),
+ BenchMode(
+ with_mtp=True,
+ label=f"top-p{bench_top_p} w/ MTP",
+ use_topp=True,
+ top_p=bench_top_p,
+ top_k=args.top_k,
+ temperature=args.temperature,
+ ),
+ ]
+
+ if args.modes:
+ allowed = {m.strip() for m in args.modes.split(",")}
+ modes = [m for m in modes if any(a in m.label for a in allowed)]
+ if not modes:
+ raise SystemExit(
+ f"Error: --modes '{args.modes}' matched no benchmark modes. "
+ f"Valid tokens: top-k1, top-p0.95"
+ )
+
+ t0 = time.monotonic()
+ workload_runners = []
+ allowed_workloads = (
+ {w.strip() for w in args.workloads.split(",")}
+ if args.workloads
+ else {"short", "coding", "long"}
+ )
+ if "short" in allowed_workloads:
+ workload_runners.append(short_bench.run)
+ if "coding" in allowed_workloads:
+ workload_runners.append(coding_bench.run)
+ if "long" in allowed_workloads:
+ workload_runners.append(long_bench.run)
+ if not workload_runners:
+ raise SystemExit(
+ f"Error: --workloads '{args.workloads}' matched no workloads. "
+ f"Valid values: short, coding, long"
+ )
+
+ all_bench_results = [
+ runner(generator, modes) for runner in workload_runners # type: ignore[arg-type]
+ ]
+ bench_time = time.monotonic() - t0
+ all_bench_stats = [stats for stats, _ in all_bench_results]
+
+ print_summary_table(
+ merge_stats(all_bench_stats),
+ model_name=model_name,
+ )
+
+ total = load_time + bench_time
+ print(f"\n## {model_name} Timing")
+ print()
+ print("| Phase | Time |")
+ print("|-------|------|")
+ print(f"| Loading | {load_time:.1f}s |")
+ print(f"| Benchmark | {bench_time:.1f}s |")
+ print(f"| **Total** | **{total:.1f}s** |")
+
+ print("Cleaning up...")
+ generator.cleanup()
diff --git a/python/models/__init__.py b/tilert/models/__init__.py
similarity index 100%
rename from python/models/__init__.py
rename to tilert/models/__init__.py
diff --git a/python/models/base.py b/tilert/models/base.py
similarity index 83%
rename from python/models/base.py
rename to tilert/models/base.py
index 58171a7..b9f5d4d 100644
--- a/python/models/base.py
+++ b/tilert/models/base.py
@@ -3,7 +3,7 @@
import os
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Any
+from typing import Any, ClassVar
import torch
import torch.nn as nn
@@ -17,11 +17,13 @@
"TileRTModule",
]
+ModelArgsLike = Any
+
class TilertWeightsConverter:
"""Tilert weights converter"""
- def __init__(self, model_args: ModelArgs, num_devices: int):
+ def __init__(self, model_args: ModelArgsLike, num_devices: int):
self.model_args = model_args
self.num_devices = num_devices
@@ -38,6 +40,29 @@ class TileRTModule(nn.Module, ABC):
own forward method.
"""
+ _SUPPORTED_ALGORITHMS: ClassVar[dict[str, list[Enum]]] = {}
+ _VALID_COMPUTE_KERNEL_TYPES: ClassVar[frozenset[str]] = frozenset(
+ {
+ "bf16",
+ "fp8",
+ "fp8mma",
+ "general",
+ "bf16mma",
+ "fp16mma",
+ "fp8mma_68cta",
+ }
+ )
+
+ @classmethod
+ def get_supported_algorithms(cls, arch_name: str) -> list[Enum]:
+ """Return supported algorithms for the given architecture."""
+ if arch_name not in cls._SUPPORTED_ALGORITHMS:
+ raise ValueError(
+ f"{cls.__name__} does not support arch '{arch_name}'. "
+ f"Supported: {list(cls._SUPPORTED_ALGORITHMS.keys())}"
+ )
+ return cls._SUPPORTED_ALGORITHMS[arch_name]
+
def __init__(
self,
op_name: str = "",
@@ -45,7 +70,7 @@ def __init__(
tilert_weights_dir: str = "",
layer_idx: int = 0,
compute_kernel_type: str = "bf16",
- model_args: ModelArgs | None = None,
+ model_args: ModelArgsLike | None = None,
num_devices: int = 8,
device_id: int = 0,
*args: Any,
@@ -64,7 +89,7 @@ def __init__(
"""
super().__init__(*args, **kwargs)
- self.model_args = model_args if model_args is not None else ModelArgs()
+ self.model_args: ModelArgsLike = model_args if model_args is not None else ModelArgs()
self.num_devices = num_devices
self.device_id = device_id
self.algorithm: Enum | None = None
@@ -79,10 +104,10 @@ def __init__(
self.flag_enable_tilert = False
- if compute_kernel_type not in ["bf16", "fp8", "fp8mma"]:
+ if compute_kernel_type not in self._VALID_COMPUTE_KERNEL_TYPES:
raise ValueError(
- f"Invalid compute kernel type: {compute_kernel_type}, \
- must be one of bf16, fp8, fp8mma."
+ f"Invalid compute kernel type: {compute_kernel_type}, "
+ f"must be one of {sorted(self._VALID_COMPUTE_KERNEL_TYPES)}."
)
self.compute_kernel_type = compute_kernel_type
@@ -112,6 +137,14 @@ def set_algorithm(self, algorithm: Enum) -> None:
Args:
algorithm: Algorithm.
"""
+ if self._SUPPORTED_ALGORITHMS:
+ arch = self.model_args.arch_name
+ supported = self.get_supported_algorithms(arch)
+ if algorithm not in supported:
+ raise ValueError(
+ f"{type(self).__name__}: algorithm {algorithm} not supported "
+ f"for arch '{arch}'. Supported: {supported}"
+ )
self.algorithm = algorithm
def register_weights(self, weights_config: dict[str, dict[str, Any]]) -> None:
@@ -214,7 +247,11 @@ class SerializableTileRTModule(TileRTModule):
"""Serializable TileRT module."""
def __init__(
- self, model_args: ModelArgs, device_id: int, num_devices: int, remove_selected: bool = False
+ self,
+ model_args: ModelArgsLike,
+ device_id: int,
+ num_devices: int,
+ remove_selected: bool = False,
):
super().__init__(
type(self).__name__, model_args=model_args, device_id=device_id, num_devices=num_devices
@@ -284,14 +321,21 @@ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
for op, prefix, suffix, retain_weights in zip(
self.exec_seq, self.prefix_seq, self.suffix_seq, self.retain_weights_seq
):
+ if op.is_tilert_weights_init:
+ logger.debug(f"Skipping init_tilert_weights for {op.op_name} (already initialized)")
+ continue
+
keys_to_remove = set()
op_state_dict = {}
for op_key in op.get_tilert_weights_alias():
original_key = f"{prefix}{op_key}{suffix}"
- op_state_dict[op_key] = state_dict[original_key]
- if self.remove_selected:
- keys_to_remove.add(original_key)
+ if original_key in state_dict:
+ op_state_dict[op_key] = state_dict[original_key]
+ if self.remove_selected:
+ keys_to_remove.add(original_key)
+
op.init_tilert_weights(op_state_dict)
+
if self.remove_selected and not retain_weights:
for k in keys_to_remove:
del state_dict[k]
diff --git a/tilert/models/common.py b/tilert/models/common.py
new file mode 100644
index 0000000..6d6f436
--- /dev/null
+++ b/tilert/models/common.py
@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+if TYPE_CHECKING:
+ from tilert.models.deepseek_v3_2.refs.kernel import act_quant, fp8_gemm, weight_dequant
+
+__all__ = [
+ "act_quant",
+ "fp8_gemm",
+ "weight_dequant",
+ "init_func",
+ "linear",
+ "RMSNorm",
+]
+
+from tilert.models.deepseek_config import (
+ block_size,
+ gemm_impl,
+)
+
+_LAZY_IMPORTS = {"act_quant", "fp8_gemm", "weight_dequant"}
+
+
+def __getattr__(name: str) -> object:
+ if name in _LAZY_IMPORTS:
+ from tilert.models.deepseek_v3_2.refs import kernel
+
+ attr = getattr(kernel, name)
+ globals()[name] = attr
+ return attr
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def _get_scale_tensor(tensor: torch.Tensor) -> torch.Tensor:
+ """Return the dynamically attached ``scale`` tensor."""
+ scale = getattr(tensor, "scale", None)
+ if scale is None:
+ raise AttributeError("Expected quantized tensor to carry a 'scale' attribute.")
+ return cast(torch.Tensor, scale)
+
+
+def init_func(x_in: torch.Tensor) -> torch.Tensor:
+ x_dtype = x_in.dtype
+ x_fp32 = x_in.to(torch.float32)
+ if x_fp32.dim() >= 2:
+ initial_tensor = nn.init.kaiming_uniform_(x_fp32)
+ else:
+ initial_tensor = nn.init.uniform_(x_fp32)
+ return initial_tensor.to(x_dtype)
+
+
+def linear(
+ x_in: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ scale_fmt: str | None = None,
+) -> torch.Tensor:
+ """
+ Applies a linear transformation to the incoming data: y = xA^T + b.
+
+ Args:
+ x_in (torch.Tensor): The input tensor.
+ weight (torch.Tensor): The weight tensor. It may be quantized.
+ bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
+
+ Returns:
+ torch.Tensor: The result of the linear transformation.
+ """
+ if weight.element_size() > 1:
+ return F.linear(x_in, weight, bias)
+
+ from tilert.models.deepseek_v3_2.refs.kernel import act_quant, fp8_gemm, weight_dequant
+
+ if gemm_impl == "bf16":
+ weight = weight_dequant(weight, _get_scale_tensor(weight))
+ return F.linear(x_in, weight, bias)
+
+ x_quant: torch.Tensor
+ scale: torch.Tensor
+ x_quant, scale = act_quant(x_in, block_size, scale_fmt)
+ y_out: torch.Tensor = fp8_gemm(x_quant, scale, weight, _get_scale_tensor(weight))
+ if bias is not None:
+ y_out += bias
+ return y_out
+
+
+class RMSNorm(nn.Module):
+ """
+ Root Mean Square Layer Normalization (RMSNorm).
+
+ Args:
+ dim (int): Dimension of the input tensor.
+ eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6, weight: torch.Tensor | None = None):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+
+ if weight is None:
+ self.weight = nn.Parameter(init_func(torch.empty(dim, dtype=torch.float32)))
+ else:
+ self.weight = torch.nn.Parameter(weight)
+
+ def forward(
+ self, x: torch.Tensor, residual: torch.Tensor | None = None
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass for RMSNorm.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Normalized tensor with the same shape as input.
+ """
+ dtype = torch.bfloat16
+ if residual is None:
+ x = x.float()
+ var_s = x.pow(2).mean(-1, keepdim=True)
+ x = x * torch.rsqrt(var_s + self.eps)
+ return (self.weight * x).to(dtype)
+
+ x = residual = x.float() + residual.float()
+ var_s = x.pow(2).mean(-1, keepdim=True)
+ x = x * torch.rsqrt(var_s + self.eps)
+ return (self.weight * x).to(dtype), residual.to(dtype)
diff --git a/tilert/models/deepseek_config.py b/tilert/models/deepseek_config.py
new file mode 100644
index 0000000..ec3701d
--- /dev/null
+++ b/tilert/models/deepseek_config.py
@@ -0,0 +1,24 @@
+"""Global configuration for DeepSeek models."""
+
+from typing import Literal
+
+import torch.distributed as dist
+
+__all__ = [
+ "get_world_size",
+ "get_rank",
+ "block_size",
+ "gemm_impl",
+]
+
+
+def get_world_size() -> int:
+ return dist.get_world_size() if dist.is_initialized() else 8
+
+
+def get_rank() -> int:
+ return dist.get_rank() if dist.is_initialized() else 0
+
+
+block_size = 128
+gemm_impl: Literal["bf16", "fp8"] = "bf16"
diff --git a/python/models/deepseek_v3_2/__init__.py b/tilert/models/deepseek_v3_2/__init__.py
similarity index 100%
rename from python/models/deepseek_v3_2/__init__.py
rename to tilert/models/deepseek_v3_2/__init__.py
diff --git a/python/models/deepseek_v3_2/generator.py b/tilert/models/deepseek_v3_2/generator.py
similarity index 81%
rename from python/models/deepseek_v3_2/generator.py
rename to tilert/models/deepseek_v3_2/generator.py
index 3813259..fb7a467 100644
--- a/python/models/deepseek_v3_2/generator.py
+++ b/tilert/models/deepseek_v3_2/generator.py
@@ -40,6 +40,7 @@ def __init__(
top_p: float = 0.9,
top_k: int = 256,
sampling_seed: int = 42,
+ enable_thinking: bool = False,
):
"""Initialize the DSAv32Generator.
@@ -52,6 +53,8 @@ def __init__(
top_p: Top-p threshold for nucleus sampling. Defaults to 0.9.
top_k: Number of top-k candidates for top-p sampling. Defaults to 256.
sampling_seed: Sampling seed for top-p (fixed per request). Defaults to 42.
+ enable_thinking: Whether to enable thinking mode in the chat template.
+ Maps to the DSv32 tokenizer's ``thinking`` Jinja variable.
"""
torch.set_num_threads(64)
self.model_weights_dir = model_weights_dir
@@ -63,13 +66,14 @@ def __init__(
self.top_p = top_p
self.top_k = top_k
self.sampling_seed = sampling_seed
+ self.enable_thinking = enable_thinking
self.config = model_args
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_weights_dir, trust_remote_code=True
) # nosec B615
self.eos_id = self.tokenizer.eos_token_id
- self.batch_size = 1 # fixed batch size to 1 for now
+ self.batch_size = 1
self.default_device = torch.device("cuda:0")
@@ -100,6 +104,37 @@ def from_pretrained(self) -> None:
"""Load the model weights from the given path."""
self.decode_layer.from_pretrained(self.model_weights_dir)
+ def extract_ffn_cache(self) -> tuple[dict[int, list], dict[int, set[str]]]:
+ """Extract MOE/MLP op objects and skip keys from current loaded weights.
+
+ Returns:
+ Tuple of (cached_ffn_ops_per_device, skip_keys_per_device).
+ """
+ from tilert.models.deepseek_v3_2.modules.end2end import (
+ _extract_ffn_ops,
+ _get_moe_weight_keys,
+ )
+
+ cached_ffn_ops: dict[int, list] = {}
+ skip_keys: dict[int, set[str]] = {}
+ for device_id in range(self.decode_layer.num_devices):
+ dsa = self.decode_layer._dsa_objects[device_id]
+ if dsa is None:
+ raise RuntimeError(f"Device {device_id} Dsa not available for cache extraction")
+ cached_ffn_ops[device_id] = _extract_ffn_ops(dsa)
+ skip_keys[device_id] = _get_moe_weight_keys(dsa)
+ return cached_ffn_ops, skip_keys
+
+ def from_pretrained_with_cache(
+ self,
+ cached_ffn_ops_per_device: dict[int, list],
+ skip_keys_per_device: dict[int, set[str]],
+ ) -> None:
+ """Load weights reusing cached MOE/MLP ops."""
+ self.decode_layer.from_pretrained_with_cache(
+ self.model_weights_dir, cached_ffn_ops_per_device, skip_keys_per_device
+ )
+
def update_sampling_params(
self,
temperature: float = 1.0,
@@ -123,7 +158,7 @@ def generate(
print_log: bool = True,
with_mtp: bool | None = None,
prompt_tokens: list[int] | None = None,
- ) -> tuple[str, list[float], list[int]]:
+ ) -> tuple[str, list[float], list[int], int]:
"""Main function to load the model and perform single sequence generation.
Args:
@@ -135,7 +170,7 @@ def generate(
and use these tokens directly (useful for exact-length benchmarking).
Returns:
- Tuple of (result_text, time_list, accepted_counts).
+ Tuple of (result_text, time_list, accepted_counts, prompt_len).
accepted_counts is empty for non-MTP mode.
"""
active_mtp = with_mtp if with_mtp is not None else self.with_mtp
@@ -144,10 +179,10 @@ def generate(
self.decode_layer.set_sampling_seed(self.sampling_seed, with_mtp=active_mtp)
if active_mtp:
return self._generate_with_mtp(prompt, print_log, prompt_tokens=prompt_tokens)
- result, time_list = self._generate_without_mtp(
+ result, time_list, prompt_len = self._generate_without_mtp(
prompt, print_log, with_mtp=active_mtp, prompt_tokens=prompt_tokens
)
- return result, time_list, [] # Empty accepted_counts for non-MTP
+ return result, time_list, [], prompt_len
def _generate_without_mtp(
self,
@@ -155,17 +190,15 @@ def _generate_without_mtp(
print_log: bool = True,
with_mtp: bool = False,
prompt_tokens: list[int] | None = None,
- ) -> tuple[str, list[float]]:
+ ) -> tuple[str, list[float], int]:
"""Standard generation without MTP."""
if prompt_tokens is None:
prompt_tokens = self.tokenizer.apply_chat_template(
- [{"role": "user", "content": prompt}], add_generation_prompt=True
+ [{"role": "user", "content": prompt}],
+ add_generation_prompt=True,
+ thinking=self.enable_thinking,
)
- # adapt to transformers 5.2.0
- if not isinstance(prompt_tokens, list) and prompt_tokens.get("input_ids") is not None:
- prompt_tokens = prompt_tokens["input_ids"]
- assert prompt_tokens is not None
max_seq_len = self.config.max_seq_len
prompt_len = len(prompt_tokens)
total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
@@ -193,9 +226,8 @@ def _generate_without_mtp(
time_list.append(end_time - start_time)
intermediates, *_ = multi_devices_results[0]
- next_token = intermediates[Idx.TOKEN_OUT][0][0] # only the first token
+ next_token = intermediates[Idx.TOKEN_OUT][0][0]
- # replace the next token with the prompt token if the prompt mask is True
next_token = torch.where(
prompt_mask[0, cur_pos_val], tokens[0, cur_pos_val], next_token
)
@@ -219,7 +251,6 @@ def _generate_without_mtp(
stats_time(time_list, "==== Performance ====")
print("\n")
- # Reset sequence after generation, i.e. reset the cur_pos to 0 internally
self.decode_layer.reset_sequence()
completion_tokens = []
@@ -231,29 +262,26 @@ def _generate_without_mtp(
decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
- return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list
+ return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list, prompt_len
def _generate_with_mtp(
self,
prompt: str,
print_log: bool = True,
prompt_tokens: list[int] | None = None,
- ) -> tuple[str, list[float], list[int]]:
+ ) -> tuple[str, list[float], list[int], int]:
"""Generation with MTP (Multi-Token Prediction) speculative decoding."""
if prompt_tokens is None:
prompt_tokens = self.tokenizer.apply_chat_template(
- [{"role": "user", "content": prompt}], add_generation_prompt=True
+ [{"role": "user", "content": prompt}],
+ add_generation_prompt=True,
+ thinking=self.enable_thinking,
)
- # adapt to transformers 5.2.0
- if not isinstance(prompt_tokens, list) and prompt_tokens.get("input_ids") is not None:
- prompt_tokens = prompt_tokens["input_ids"]
- assert prompt_tokens is not None
max_seq_len = self.config.max_seq_len
prompt_len = len(prompt_tokens)
total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
- # Output tokens buffer
tokens = torch.full(
(self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device
)
@@ -263,17 +291,14 @@ def _generate_with_mtp(
prefill_time_list = []
decode_time_list = []
- decode_accepted_counts = [] # Only track decode phase for statistics
- cur_pos = 0 # Current position in the output sequence
+ decode_accepted_counts = []
+ cur_pos = 0
- # Prefill phase: process prompt tokens in non-overlapping chunks.
- # Each chunk fills unique KV cache positions for both main model and MTP[0].
while cur_pos < prompt_len - 1:
draft_end = min(cur_pos + self.mtp_seq_len, prompt_len)
draft_tokens = tokens[0, cur_pos:draft_end].clone()
actual_token_count = draft_tokens.shape[0]
- # Pad if needed (use last token for padding)
if actual_token_count < self.mtp_seq_len:
pad_token = draft_tokens[-1].item()
padding = torch.full(
@@ -286,18 +311,13 @@ def _generate_with_mtp(
draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
- # Provide the extra token for MTP[0]'s shifted input last position.
- # MTP[0] needs tokens[cur_pos+1 : cur_pos+mtp_seq_len+1], so the
- # extra token is at cur_pos + mtp_seq_len.
mtp_extra_pos = cur_pos + self.mtp_seq_len
if mtp_extra_pos < prompt_len:
mtp_extra_token = int(tokens[0, mtp_extra_pos].item())
else:
- # Beyond prompt — use last valid draft token as padding
mtp_extra_token = int(tokens[0, draft_end - 1].item())
self.decode_layer.set_prefill_mtp_extra_token(mtp_extra_token)
- # Tell GPU how many tokens are valid (for cur_pos advancement)
self.decode_layer.set_prefill_valid_tokens(actual_token_count)
start_time = time.time()
@@ -305,27 +325,16 @@ def _generate_with_mtp(
end_time = time.time()
prefill_time_list.append(end_time - start_time)
- # No overlap: advance by the full actual_token_count
cur_pos += actual_token_count
- # After no-overlap prefill, cur_pos may have overshot to prompt_len.
- # Reset to prompt_len - 1 for correct decode start (first decode
- # reprocesses the last prompt token position).
cur_pos = prompt_len - 1
self.set_cur_pos(prompt_len - 1)
- # Decode phase: speculative decoding
- # Set prefill_valid_tokens to 0 to switch to decode mode
self.decode_layer.set_prefill_valid_tokens(0)
finished = False
while cur_pos < total_len - 1 and not finished:
- # Get next_draft_tokens from previous iteration
- # (or use last prompt tokens for first decode)
if cur_pos == prompt_len - 1:
- # First decode iteration: use last prompt token repeated as placeholder drafts
- # We can't use [t6, t7, t8, t9] because that would apply wrong RoPE positions
- # (cur_pos=9 means positions 9,10,11,12, but t6 should be at position 6)
last_token = tokens[0, prompt_len - 1].item()
draft_tokens = torch.full(
(self.mtp_seq_len,),
@@ -335,7 +344,6 @@ def _generate_with_mtp(
)
draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
else:
- # Use next_draft_tokens from previous iteration
draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape(
1, self.mtp_seq_len
)
@@ -346,11 +354,9 @@ def _generate_with_mtp(
decode_time_list.append(end_time - start_time)
num_accepted = self.decode_layer.get_num_accepted(0)
- # Use predicted_tokens for output (not next_draft_tokens which is for next iteration)
predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten()
decode_accepted_counts.append(num_accepted)
- # Add accepted tokens to output
num_output_tokens = num_accepted
for i in range(num_output_tokens):
if cur_pos + 1 + i >= total_len:
@@ -358,12 +364,10 @@ def _generate_with_mtp(
new_token = int(predicted_tokens[i].item())
tokens[0, cur_pos + 1 + i] = new_token
- # Print generated token
if cur_pos + 1 + i >= prompt_len and print_log:
decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True)
print(decoded_text, end="", flush=True)
- # Check for EOS
if new_token == self.eos_id:
finished = True
break
@@ -384,7 +388,6 @@ def _generate_with_mtp(
f"min={min_accepted}, max={max_accepted}"
)
- # Calculate correct TPS accounting for MTP's multiple tokens per call
if decode_time_list:
total_decode_time = sum(decode_time_list)
effective_tps = total_tokens / total_decode_time if total_decode_time > 0 else 0
@@ -394,14 +397,11 @@ def _generate_with_mtp(
print("\n")
- # Reset sequence after generation
self.decode_layer.reset_sequence()
- # Extract completion tokens
completion_tokens = []
for _, toks in enumerate(tokens.tolist()):
toks = toks[prompt_len : prompt_len + self.max_new_tokens]
- # Remove -1 padding and tokens after EOS
toks = [t for t in toks if t != -1]
if self.eos_id in toks:
toks = toks[: toks.index(self.eos_id)]
@@ -413,6 +413,7 @@ def _generate_with_mtp(
f"{decoded_tokens[0]}\n" if decoded_tokens else "",
decode_time_list,
decode_accepted_counts,
+ prompt_len,
)
def inject_cache(
@@ -452,7 +453,6 @@ def inject_cache(
logger.warning("inject_cache called with empty layer_caches")
return
- # Infer seqlen from first tensor if end_pos not specified
first_ki, _, _ = layer_caches[0]
seqlen = first_ki.size(0)
if end_pos is None:
@@ -473,9 +473,6 @@ def inject_cache(
base_idx = layer_id * 3
- # Copy to device and inject into cache
- # Cache layout: [batch=1, max_seq_len, dim]
- # External data: [seqlen, dim]
ki_src = ki[:cache_len].to(f"cuda:{device_id}")
kv_src = kv[:cache_len].to(f"cuda:{device_id}")
pe_src = pe[:cache_len].to(f"cuda:{device_id}")
@@ -487,14 +484,11 @@ def inject_cache(
logger.info(f"Cache injection completed for {num_devices} devices")
def set_cur_pos(self, cur_pos: int) -> None:
- """Set the current position for RoPE in C++ backend.
-
- This should be called after inject_cache() to ensure the C++ global
- g_cur_pos matches the injected cache length. This is critical for
- correct RoPE position encoding during continued generation.
+ """Set the current position for RoPE.
- For MTP mode, sets the GPU tensor at intermediates[31] directly.
- For non-MTP mode, calls the C++ dsa_show_hands_set_cur_pos API.
+ This should be called after inject_cache() to ensure the runtime position
+ matches the injected cache length, for correct RoPE position encoding
+ during continued generation.
Args:
cur_pos: The current sequence position (typically the length of prefilled tokens).
@@ -505,22 +499,19 @@ def set_cur_pos(self, cur_pos: int) -> None:
>>> # Now generate continues from the correct position
"""
if self.with_mtp:
- # MTP E2E uses g_cur_pos_tensors which is the GPU tensor
num_devices = self.decode_layer.num_devices
for device_id in range(num_devices):
intermediates, _, _, _ = self.decode_layer._get_device_result(device_id)
cur_pos_tensor = intermediates[Idx.CUR_POS]
cur_pos_tensor.fill_(cur_pos)
else:
- # Non-MTP uses the C++ global g_cur_pos
torch.ops.tilert.dsa_show_hands_set_cur_pos(cur_pos)
def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None:
"""Inject the last hidden state for MTP mode.
For MTP (Multi-Token Prediction), the MTP preprocess layer needs the
- last hidden state from the main model's last token. This method injects
- the hidden state into intermediates[33] (last_hidden_states slot).
+ last hidden state from the main model's last token.
Args:
last_hidden_state: [hidden_size] or [1, hidden_size] BF16 tensor.
@@ -535,14 +526,12 @@ def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None:
logger.warning("inject_last_hidden_state called but with_mtp is False, skipping")
return
- # Normalize shape to [1, hidden_size]
if last_hidden_state.dim() == 1:
last_hidden_state = last_hidden_state.unsqueeze(0)
num_devices = self.decode_layer.num_devices
for device_id in range(num_devices):
intermediates, _, _, _ = self.decode_layer._get_device_result(device_id)
- # Shape: [batch=1, seq=4, hidden_size], we set seq[0] since it's the last token
lhs_tensor = intermediates[Idx.LAST_HIDDEN_STATES]
lhs_src = last_hidden_state.to(f"cuda:{device_id}")
lhs_tensor[0, 0, :].copy_(lhs_src.squeeze(0))
diff --git a/python/models/deepseek_v3_2/model_args.py b/tilert/models/deepseek_v3_2/model_args.py
similarity index 92%
rename from python/models/deepseek_v3_2/model_args.py
rename to tilert/models/deepseek_v3_2/model_args.py
index b149edf..441b684 100644
--- a/python/models/deepseek_v3_2/model_args.py
+++ b/tilert/models/deepseek_v3_2/model_args.py
@@ -50,8 +50,8 @@ class ModelArgs:
arch_name = "deepseek_v3_2"
- max_batch_size: int = 1 # NOTE: the current implementation only supports a batch size being 1
- max_seq_len: int = 160 * 1024 # 160K
+ max_batch_size: int = 1
+ max_seq_len: int = 160 * 1024
dtype: Literal["bf16", "fp8"] = "fp8"
scale_fmt: str | None = None
@@ -63,23 +63,20 @@ class ModelArgs:
n_dense_layers: int = 3
n_heads: int = 128
- # moe
n_routed_experts: int = 256
n_shared_experts: int = 1
n_activated_experts: int = 8
n_expert_groups: int = 8
n_limited_groups: int = 4
- score_func: Literal["softmax", "sigmoid"] = "softmax"
+ score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "softmax"
route_scale: float = 2.5
- # mla
q_lora_rank: int = 1536
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
- # yarn
original_seq_len: int | None = 4096
rope_theta: float = 10000.0
rope_factor: float | None = 40
@@ -87,14 +84,12 @@ class ModelArgs:
beta_slow: int | None = 1
mscale: float = 1.0
- # index
index_n_heads: int = 64
index_head_dim: int = 128
index_topk: int = 2048
kv_cache_pad: int = 8
- # quant
block_size: int = 128
eps: float = 1e-6
diff --git a/python/models/deepseek_v3_2/modules/__init__.py b/tilert/models/deepseek_v3_2/modules/__init__.py
similarity index 100%
rename from python/models/deepseek_v3_2/modules/__init__.py
rename to tilert/models/deepseek_v3_2/modules/__init__.py
diff --git a/python/models/deepseek_v3_2/modules/dsa.py b/tilert/models/deepseek_v3_2/modules/dsa.py
similarity index 70%
rename from python/models/deepseek_v3_2/modules/dsa.py
rename to tilert/models/deepseek_v3_2/modules/dsa.py
index 64116f9..2efe143 100644
--- a/python/models/deepseek_v3_2/modules/dsa.py
+++ b/tilert/models/deepseek_v3_2/modules/dsa.py
@@ -13,17 +13,75 @@
class Dsa(SerializableTileRTModule):
"""DSA module."""
- def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ cached_ffn_ops: list | None = None,
+ ):
super().__init__(
model_args=model_args,
device_id=device_id,
num_devices=num_devices,
remove_selected=True,
)
+ from tilert.models.deepseek_v3_2.modules.mla_v2 import (
+ PureMlaV2,
+ SparseSelectMlaV2,
+ )
+
+ mla_cls = SparseSelectMlaV2 if device_id == 0 else PureMlaV2
+ mla_kwargs: dict = {}
+
+ dev = f"cuda:{device_id}"
+ n_peers = num_devices - 1
+ if device_id == 0:
+ self.v2_peer_bufs = torch.zeros(n_peers, dtype=torch.int64, device=dev)
+ self.v2_partial_buf = torch.zeros(
+ model_args.max_batch_size, 4, model_args.dim, dtype=torch.bfloat16, device=dev
+ )
+ mla_kwargs = {
+ "peer_bufs": self.v2_peer_bufs,
+ "partial_buf": self.v2_partial_buf,
+ }
+ else:
+ max_seq_len = getattr(model_args, "num_mtp", 3) + 1
+ topk = model_args.index_topk
+ self.v2_ll_buf = torch.zeros(max_seq_len * topk * 2, dtype=torch.int32, device=dev)
+ mla_kwargs = {"ll_buf": self.v2_ll_buf}
+
+ mla_num_devices: int | None = None
+ if device_id != 0:
+ mla_num_devices = num_devices - 1
+
+ if cached_ffn_ops is not None:
+ assert (
+ len(cached_ffn_ops) == model_args.n_layers
+ ), f"Expected {model_args.n_layers} cached FFN ops, got {len(cached_ffn_ops)}"
for layer_idx in range(model_args.n_layers):
- block_type = MlpBlock if layer_idx < model_args.n_dense_layers else MoeBlock
- block = block_type(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ ffn_op = cached_ffn_ops[layer_idx] if cached_ffn_ops else None
+ if layer_idx < model_args.n_dense_layers:
+ block = MlpBlock(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ mla_cls=mla_cls,
+ mla_num_devices=mla_num_devices,
+ mla_kwargs=mla_kwargs,
+ mlp=ffn_op,
+ )
+ else:
+ block = MoeBlock(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ mla_cls=mla_cls,
+ mla_num_devices=mla_num_devices,
+ mla_kwargs=mla_kwargs,
+ moe=ffn_op,
+ )
self.register_op(block, prefix=f"layer_{layer_idx}_", suffix=f"_dev_{device_id}")
self.register_op(
@@ -64,7 +122,17 @@ def get_temp_vars(
q_lora_rank = self.model_args.q_lora_rank
kv_lora_rank = self.model_args.kv_lora_rank
qk_nope_head_dim = self.model_args.qk_nope_head_dim
- n_local_heads = self.model_args.n_heads // self.num_devices
+ if self.device_id != 0:
+ from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqbWeightsConverter,
+ )
+
+ qk_head_dim = self.model_args.qk_nope_head_dim + self.model_args.qk_rope_head_dim
+ n_local_heads = RmsnormProjqWqbWeightsConverter._compute_n_local_heads(
+ self.model_args.n_heads, self.num_devices - 1, qk_head_dim
+ )
+ else:
+ n_local_heads = self.model_args.n_heads // self.num_devices
qk_rope_head_dim = self.model_args.qk_rope_head_dim
index_head_dim = self.model_args.index_head_dim
v_head_dim = self.model_args.v_head_dim
@@ -132,8 +200,7 @@ def get_temp_vars(
)
temp_vars[Idx.MOE_UP_GATE] = torch.zeros_like(exp_up_gate)
- # temp_vars[Idx.IDX_SEL_WS] = torch.zeros(*batch_seq, 4, index_topk * 2, **int32_desc)
- temp_vars[Idx.IDX_SEL_WS] = torch.zeros(*batch_seq, (200 * 1024 + 258), **int32_desc)
+ temp_vars[Idx.IDX_SEL_WS] = torch.zeros(*batch_seq, (200 * 1024 + 260), **int32_desc)
temp_vars[Idx.MTP0_TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc)
temp_vars[Idx.MTP1_TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc)
@@ -147,6 +214,14 @@ def get_temp_vars(
temp_vars[Idx.TOP_P_SCORES] = torch.zeros(*batch_seq, **fp32_desc)
temp_vars[Idx.TOP_P_DEBUG] = torch.zeros(*batch_seq, vocab_size, **fp32_desc)
+ temp_vars[Idx.LORA_SLOT_ID] = torch.zeros(1, **int32_desc)
+ temp_vars[Idx.LORA_RANK] = torch.zeros(1, **int32_desc)
+
+ max_top_n = 256
+ temp_vars[Idx.TOP_N_LOG_PROBS] = torch.zeros(*batch_seq, max_top_n, **fp32_desc)
+ temp_vars[Idx.TOP_N_INDICES] = torch.zeros(*batch_seq, max_top_n, **int32_desc)
+ temp_vars[Idx.LOGPROBS_FLAG] = torch.zeros(1, **int32_desc)
+
for i, t in enumerate(temp_vars):
if t is None:
raise RuntimeError(f"temp_vars[{i}] ({Idx(i).name}) was not initialized")
diff --git a/python/models/deepseek_v3_2/modules/end2end.py b/tilert/models/deepseek_v3_2/modules/end2end.py
similarity index 68%
rename from python/models/deepseek_v3_2/modules/end2end.py
rename to tilert/models/deepseek_v3_2/modules/end2end.py
index 47a5671..e1be82e 100644
--- a/python/models/deepseek_v3_2/modules/end2end.py
+++ b/tilert/models/deepseek_v3_2/modules/end2end.py
@@ -8,9 +8,11 @@
from typing import Any
import torch
+from safetensors import safe_open
from safetensors.torch import load_file
from tilert import logger
+from tilert.models.base import TileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs
from tilert.models.deepseek_v3_2.modules.dsa import Dsa
from tilert.models.deepseek_v3_2.modules.mtp import MTP
@@ -18,12 +20,62 @@
from tilert.models.utils import precompute_freqs_cis
from tilert.utils import get_profile_log_tensor
-__all__ = ["ShowHandsDSALayer"]
+__all__ = ["ShowHandsDSALayer", "_extract_ffn_ops", "_get_moe_weight_keys"]
DeviceResult = tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], torch.Tensor]
+def _mark_weights_initialized(module: TileRTModule) -> None:
+ """Recursively mark a module and all sub-ops as having initialized tilert weights."""
+ module.is_tilert_weights_init = True
+ if hasattr(module, "exec_seq"):
+ for op in module.exec_seq:
+ _mark_weights_initialized(op)
+
+
+def _extract_ffn_ops(dsa: "Dsa") -> list:
+ """Extract Moe/Mlp op objects from a Dsa's layer blocks.
+
+ Returns a list of length n_layers where each element is a Moe or Mlp instance.
+ """
+ from tilert.models.deepseek_v3_2.modules.mlp import MlpBlock
+ from tilert.models.deepseek_v3_2.modules.moe import MoeBlock
+
+ ffn_ops = []
+ for block in dsa.exec_seq:
+ if isinstance(block, MoeBlock):
+ op = block.moe
+ _mark_weights_initialized(op)
+ ffn_ops.append(op)
+ elif isinstance(block, MlpBlock):
+ op = block.mlp
+ _mark_weights_initialized(op)
+ ffn_ops.append(op)
+
+ assert (
+ len(ffn_ops) == dsa.model_args.n_layers
+ ), f"Expected {dsa.model_args.n_layers} FFN ops, got {len(ffn_ops)}"
+ return ffn_ops
+
+
+def _get_moe_weight_keys(dsa: "Dsa") -> set[str]:
+ """Get state_dict keys that belong exclusively to MOE/MLP ops in this Dsa."""
+ from tilert.models.deepseek_v3_2.modules.mlp import MlpBlock
+ from tilert.models.deepseek_v3_2.modules.moe import MoeBlock
+
+ moe_keys: set[str] = set()
+ mla_keys: set[str] = set()
+ for block, prefix, suffix in zip(dsa.exec_seq, dsa.prefix_seq, dsa.suffix_seq):
+ if isinstance(block, (MoeBlock, MlpBlock)):
+ ffn = block.moe if isinstance(block, MoeBlock) else block.mlp
+ for alias in ffn.get_tilert_weights_alias():
+ moe_keys.add(f"{prefix}{alias}{suffix}")
+ for alias in block.mla.get_tilert_weights_alias():
+ mla_keys.add(f"{prefix}{alias}{suffix}")
+ return moe_keys - mla_keys
+
+
def dsa_show_hands_prepare_money(
params: list[torch.Tensor],
temp_vars: list[torch.Tensor],
@@ -102,9 +154,6 @@ def dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(
def dsa_mtp_e2e_show_hands_set_prefill_mtp_extra_token(token: int, is_glm5: bool = False) -> Any:
"""Set the extra token for MTP[0] shifted input during prefill.
- This is the prompt token at (cur_pos + mtp_seq_len), used as the last position
- of MTP[0]'s shifted input to enable no-overlap prefill chunking.
-
Args:
token: The extra prompt token id (int32).
"""
@@ -144,6 +193,7 @@ def __init__(
self.with_mtp = with_mtp
self.multi_devices_results: list[DeviceResult | None] = [None] * torch.cuda.device_count()
+ self._dsa_objects: list[Dsa | None] = [None] * torch.cuda.device_count()
self.temperature = temperature
self.top_p = top_p
@@ -155,7 +205,11 @@ def _gen_freqs_cis(self) -> torch.Tensor:
return torch.view_as_real(freqs_cis).reshape(freqs_cis.shape[0], -1)
def load_device_weights(
- self, model_path: str, device_id: int, extra_keys: list
+ self,
+ model_path: str,
+ device_id: int,
+ extra_keys: list,
+ skip_keys: set[str] | None = None,
) -> dict[str, torch.Tensor]:
index_file = "model.safetensors.index.json"
with open(os.path.join(model_path, index_file), encoding="utf-8") as f:
@@ -165,20 +219,33 @@ def load_device_weights(
weights_list = [_k for _k in weight_file_map.keys() if _k.endswith(f"dev_{device_id}")]
weights_list = [*weights_list, *extra_keys]
+ if skip_keys:
+ weights_list = [k for k in weights_list if k not in skip_keys]
+
target_files = set()
for weight_key in weights_list:
weight_file = weight_file_map[weight_key]
target_files.add(weight_file)
state_dicts = {}
+ weights_set = set(weights_list)
for weight_file in target_files:
- logger.info(f"Loading weights from {weight_file} for device {device_id}")
- state_dict = load_file(
- os.path.join(model_path, weight_file), device=f"cuda:{device_id}"
- )
- state_dicts.update(state_dict)
- del state_dict
- torch.cuda.empty_cache()
+ filepath = os.path.join(model_path, weight_file)
+ if skip_keys:
+ logger.info(
+ f"Selectively loading weights from {weight_file} for device {device_id}"
+ )
+ with safe_open(filepath, framework="pt", device=f"cuda:{device_id}") as f:
+ for key in f.keys():
+ if key in weights_set:
+ state_dicts[key] = f.get_tensor(key)
+ torch.cuda.empty_cache()
+ else:
+ logger.info(f"Loading weights from {weight_file} for device {device_id}")
+ state_dict = load_file(filepath, device=f"cuda:{device_id}")
+ state_dicts.update(state_dict)
+ del state_dict
+ torch.cuda.empty_cache()
state_dicts["freqs_cis"] = self._gen_freqs_cis().to(device_id)
return state_dicts
@@ -186,11 +253,7 @@ def load_device_weights(
def update_sampling_config(
self, temperature: float, top_p: float, top_k: int, use_topp: bool = True
) -> None:
- """Update sampling config, re-capturing CUDA graphs if parameters changed.
-
- Sampling parameters are baked into CUDA graph instructions at prepare_money
- time, so any change requires a full teardown + re-capture cycle.
- """
+ """Update sampling config, re-capturing CUDA graphs if parameters changed."""
new_config = (temperature, top_p, top_k, use_topp)
current_config = (self.temperature, self.top_p, self.top_k, self.use_topp)
if new_config == current_config:
@@ -201,20 +264,17 @@ def update_sampling_config(
f"temperature={temperature}, top_p={top_p}, top_k={top_k}, use_topp={use_topp}"
)
- # Teardown: stop all threads and unregister all modules
if self.with_mtp:
dsa_show_hands_go_home(True, self.is_glm5)
dsa_show_hands_go_home(False, self.is_glm5)
else:
dsa_show_hands_go_home(False, self.is_glm5)
- # Store new config
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.use_topp = use_topp
- # Update sampling_config tensor on all devices
for device_id in range(self.num_devices):
result = self.multi_devices_results[device_id]
if result is not None:
@@ -227,7 +287,6 @@ def update_sampling_config(
)
)
- # Re-prepare all modules (re-captures CUDA graphs with new config)
for device_id in range(self.num_devices):
with torch.cuda.device(device_id):
intermediates, caches, params, profile_logs = self._get_device_result(device_id)
@@ -274,8 +333,23 @@ def generate_params_with_continuous_storage(
offset += aligned_param_size
return cloned_params
- def _init_weights(self, model_path: str | None) -> None:
- """Load the model weights from the given path or generate random weights."""
+ def _init_weights(
+ self,
+ model_path: str | None,
+ cached_ffn_ops_per_device: dict[int, list] | None = None,
+ skip_keys_per_device: dict[int, set[str]] | None = None,
+ ) -> None:
+ """Load the model weights from the given path or generate random weights.
+
+ Args:
+ model_path: Path to the model weights directory.
+ cached_ffn_ops_per_device: Optional dict mapping device_id to cached FFN ops.
+ When provided, these ops are injected into the Dsa and their weights
+ are not re-loaded from disk.
+ skip_keys_per_device: Optional dict mapping device_id to safetensors keys
+ to skip during loading. Used together with cached_ffn_ops_per_device.
+ """
+ self._v2_p2p: dict = {}
def __load_weights(device_id: int, model_path: str | None) -> None:
intermediates: list[torch.Tensor] = []
@@ -284,8 +358,12 @@ def __load_weights(device_id: int, model_path: str | None) -> None:
state_dicts = {}
start_time = time.time()
with torch.cuda.device(device_id):
- assert model_path is not None # Type narrowing for mypy
- # state_dicts = _load_state_dicts(model_path, dev_attrs)
+ assert model_path is not None
+ skip_keys = (
+ skip_keys_per_device.get(device_id)
+ if skip_keys_per_device is not None
+ else None
+ )
state_dicts = self.load_device_weights(
model_path,
device_id,
@@ -294,12 +372,33 @@ def __load_weights(device_id: int, model_path: str | None) -> None:
f"layer_{self.model_args.n_layers}_lm_head.weight_dev_{device_id}",
f"layer_{self.model_args.n_layers}_model.norm.weight_dev_{device_id}",
],
+ skip_keys=skip_keys,
)
- dsa = Dsa(self.model_args, device_id, self.num_devices)
+ cached_ffn_ops = (
+ cached_ffn_ops_per_device.get(device_id)
+ if cached_ffn_ops_per_device is not None
+ else None
+ )
+ dsa = Dsa(
+ self.model_args,
+ device_id,
+ self.num_devices,
+ cached_ffn_ops=cached_ffn_ops,
+ )
dsa.init_tilert_weights(state_dicts)
+ self._dsa_objects[device_id] = dsa
params.extend(dsa.get_weights_list())
caches.extend(dsa.get_cache_vars())
+
+ if device_id == 0:
+ self._v2_p2p[device_id] = {
+ "peer_bufs": dsa.v2_peer_bufs,
+ }
+ else:
+ self._v2_p2p[device_id] = {
+ "ll_buf": dsa.v2_ll_buf,
+ }
intermediates.extend(
self.generate_params_with_continuous_storage(
dsa.get_temp_vars(
@@ -316,8 +415,6 @@ def __load_weights(device_id: int, model_path: str | None) -> None:
)
)
- # generate_params_with_continuous_storage creates zero-filled views.
- # Populate sampling_config with actual values.
sampling_config = intermediates[Idx.SAMPLING_CONFIG]
sampling_config.copy_(
torch.tensor(
@@ -325,20 +422,32 @@ def __load_weights(device_id: int, model_path: str | None) -> None:
self.temperature,
self.top_p,
float(self.top_k),
- 1.0 if self.use_topp else 0.0, # 0=top1(default), 1=topp
+ 1.0 if self.use_topp else 0.0,
],
dtype=torch.float32,
device=device_id,
)
)
- # Track base (non-MTP) params/caches count for dual-module init
base_params_count = len(params)
base_caches_count = len(caches)
- # Add MTP-specific params when with_mtp is True
if self.with_mtp:
- mtp = MTP(self.model_args, device_id, self.num_devices)
+ from tilert.models.deepseek_v3_2.modules.mla_v2 import (
+ PureMlaV2,
+ SparseSelectMlaV2,
+ )
+
+ mtp_kwargs: dict = {}
+ mtp_kwargs["mla_cls"] = SparseSelectMlaV2 if device_id == 0 else PureMlaV2
+ mtp_kwargs["mla_num_devices"] = 1 if device_id == 0 else self.num_devices - 1
+ if device_id == 0:
+ mtp_kwargs["mla_kwargs"] = {
+ "peer_bufs": dsa.v2_peer_bufs,
+ }
+ else:
+ mtp_kwargs["mla_kwargs"] = {"ll_buf": dsa.v2_ll_buf}
+ mtp = MTP(self.model_args, device_id, self.num_devices, **mtp_kwargs)
mtp.init_tilert_weights(state_dicts)
params.extend(mtp.get_weights_list())
caches.extend(mtp.get_cache_vars())
@@ -379,11 +488,21 @@ def _runner(dev_id: int) -> None:
if exc is not None:
raise RuntimeError(f"Failed to initialize device {device_id}: {exc}") from exc
- # Prepare money for all devices
+ if self._v2_p2p:
+ gpu0 = self._v2_p2p[0]
+ peer_bufs_cpu = torch.zeros(self.num_devices - 1, dtype=torch.int64)
+ for i in range(self.num_devices - 1):
+ dev_id = i + 1
+ peer_bufs_cpu[i] = self._v2_p2p[dev_id]["ll_buf"].data_ptr()
+ gpu0["peer_bufs"].copy_(peer_bufs_cpu)
+ logger.info(
+ "V2 P2P exchange complete: peer_bufs (ll_buf)=%s",
+ [hex(int(x)) for x in peer_bufs_cpu],
+ )
+
for device_id in range(self.num_devices):
with torch.cuda.device(device_id):
intermediates, caches, params, profile_logs = self._get_device_result(device_id)
- # Always prepare the primary module (MTP if with_mtp, else non-MTP)
dsa_show_hands_prepare_money(
params,
intermediates,
@@ -393,7 +512,6 @@ def _runner(dev_id: int) -> None:
self.with_mtp,
self.is_glm5,
)
- # When MTP-capable, also prepare the non-MTP module using base params/caches
if self.with_mtp:
dsa_show_hands_prepare_money(
params[: self._base_params_count],
@@ -411,6 +529,21 @@ def from_pretrained(self, model_path: str) -> None:
raise ValueError(f"Model weights directory {model_path} does not exist")
self._init_weights(model_path)
+ def from_pretrained_with_cache(
+ self,
+ model_path: str,
+ cached_ffn_ops_per_device: dict[int, list],
+ skip_keys_per_device: dict[int, set[str]],
+ ) -> None:
+ """Load weights with cached MOE/MLP ops."""
+ if not os.path.exists(model_path):
+ raise ValueError(f"Model weights directory {model_path} does not exist")
+ self._init_weights(
+ model_path,
+ cached_ffn_ops_per_device=cached_ffn_ops_per_device,
+ skip_keys_per_device=skip_keys_per_device,
+ )
+
def init_random_weights(self) -> None:
"""Generate random weights."""
self._init_weights(None)
@@ -438,7 +571,6 @@ def set_sampling_seed(self, seed: int, with_mtp: bool | None = None) -> None:
def reset_sequence(self) -> None:
if self.with_mtp:
- # Reset both MTP and non-MTP modules for clean state
dsa_show_hands_reset(True, self.is_glm5)
dsa_show_hands_reset(False, self.is_glm5)
else:
@@ -446,7 +578,6 @@ def reset_sequence(self) -> None:
def cleanup(self) -> None:
if self.with_mtp:
- # Cleanup both MTP and non-MTP modules
dsa_show_hands_go_home(True, self.is_glm5)
dsa_show_hands_go_home(False, self.is_glm5)
else:
@@ -518,3 +649,55 @@ def get_predicted_tokens(self, device_id: int = 0) -> torch.Tensor:
"""
intermediates, _, _, _ = self._get_device_result(device_id)
return intermediates[Idx.PREDICTED_TOKENS]
+
+ def get_logits(self, device_id: int = 0) -> torch.Tensor:
+ """Get logits from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Logits tensor of shape [batch, seq_len, vocab_size] (FP32).
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return intermediates[Idx.LOGITS_OUT]
+
+ def get_top_n_logprobs(self, device_id: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
+ """Get top-N log-probabilities and token IDs from the top_p kernel.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Tuple of (log_probs, token_ids):
+ - log_probs: [batch, seq_len, 256] FP32
+ - token_ids: [batch, seq_len, 256] INT32
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return (
+ intermediates[Idx.TOP_N_LOG_PROBS],
+ intermediates[Idx.TOP_N_INDICES],
+ )
+
+ def get_token_logprob(self, device_id: int = 0) -> torch.Tensor:
+ """Get log-probability of the sampled token (from TOP_P_SCORES).
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Tensor of shape [batch, seq_len] (FP32).
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return intermediates[Idx.TOP_P_SCORES]
+
+ def set_logprobs_enabled(self, enabled: bool) -> None:
+ """Enable or disable logprobs export in the top_p kernel.
+
+ Args:
+ enabled: True to enable logprobs export, False to disable.
+ """
+ flag_val = 1 if enabled else 0
+ for device_id in range(self.num_devices):
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ intermediates[Idx.LOGPROBS_FLAG].fill_(flag_val)
diff --git a/tilert/models/deepseek_v3_2/modules/mla_v2.py b/tilert/models/deepseek_v3_2/modules/mla_v2.py
new file mode 100644
index 0000000..bbf4f14
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/modules/mla_v2.py
@@ -0,0 +1,248 @@
+"""MLA weight generator classes for device-group-specific pipelines."""
+
+import torch
+
+from tilert.models.base import SerializableTileRTModule
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.deepseek_v3_2.ops.layernorm_rope_rotate import LayerNormRoPERotate
+from tilert.models.deepseek_v3_2.ops.projo_wkvb import ProjoWKVb
+from tilert.models.deepseek_v3_2.ops.projq_wqb import ProjqWqb
+from tilert.models.deepseek_v3_2.ops.projx_wis import ProjxWis
+from tilert.models.deepseek_v3_2.ops.rmsnorm_kv import KVRMSNorm
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqb,
+ RmsnormProjqWqbAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqi import (
+ RmsnormProjqWqi,
+ RmsnormProjqWqiAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqakis import (
+ RMSNormProjxWqakis,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkva import (
+ RMSNormProjxWqkva,
+ RMSNormProjxWqkvaAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.unproj_o_allreduce import (
+ UnProjOAllReduce,
+ UnProjOAllReduceAlgorithm,
+)
+
+
+class SparseSelectMlaV2(SerializableTileRTModule):
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ peer_bufs: torch.Tensor | None = None,
+ partial_buf: torch.Tensor | None = None,
+ ):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.rmsnorm_projx_wqakis = RMSNormProjxWqakis(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.rmsnorm_projx_wqakis)
+
+ self.rmsnorm_projq_wqi = RmsnormProjqWqi(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.rmsnorm_projq_wqi.algorithm = RmsnormProjqWqiAlgorithm.BF16MMA
+ self.register_op(self.rmsnorm_projq_wqi)
+
+ self.layernorm_rope_rotate = LayerNormRoPERotate(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.layernorm_rope_rotate)
+
+ self.projx_wis = ProjxWis(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.projx_wis)
+
+ self.peer_bufs = peer_bufs
+ self.partial_buf = partial_buf
+
+ self.ki_cache: torch.Tensor | None = None
+ self.kv_cache: torch.Tensor | None = None
+ self.pe_cache: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """Return weight tensors."""
+ weights = super().get_weights_list()
+
+ dev = f"cuda:{self.device_id}"
+ if self.peer_bufs is None:
+ self.peer_bufs = torch.zeros(self.num_devices - 1, dtype=torch.int64, device=dev)
+ if self.partial_buf is None:
+ self.partial_buf = torch.zeros(
+ self.model_args.max_batch_size,
+ 4,
+ self.model_args.dim,
+ dtype=torch.bfloat16,
+ device=dev,
+ )
+
+ weights.append(self.peer_bufs)
+ weights.append(self.partial_buf)
+
+ return weights
+
+ def get_cache_vars(self) -> list[torch.Tensor]:
+ """Return [ki_cache, kv_cache, pe_cache] matching DsaCacheVars layout."""
+ cache_seq_len = self.model_args.max_seq_len + self.model_args.kv_cache_pad
+ bs_args = (self.model_args.max_batch_size, cache_seq_len)
+
+ if self.ki_cache is None:
+ ki_dim = self.model_args.index_head_dim
+ self.ki_cache = torch.zeros(
+ *bs_args, ki_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.kv_cache is None:
+ kv_dim = self.model_args.kv_lora_rank
+ self.kv_cache = torch.zeros(
+ *bs_args, kv_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.pe_cache is None:
+ pe_dim = self.model_args.qk_rope_head_dim
+ self.pe_cache = torch.zeros(
+ *bs_args, pe_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ return [*super().get_cache_vars(), self.ki_cache, self.kv_cache, self.pe_cache]
+
+
+class PureMlaV2(SerializableTileRTModule):
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ ll_buf: torch.Tensor | None = None,
+ ):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.rmsnorm_projx_wqkva = RMSNormProjxWqkva(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.rmsnorm_projx_wqkva.algorithm = RMSNormProjxWqkvaAlgorithm.DECOUPLED
+ self.register_op(self.rmsnorm_projx_wqkva)
+
+ self.rmsnorm_projq_wqb = RmsnormProjqWqb(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.rmsnorm_projq_wqb.algorithm = RmsnormProjqWqbAlgorithm.BF16MMA
+ self.register_op(self.rmsnorm_projq_wqb)
+
+ self.rmsnorm_kv = KVRMSNorm(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.rmsnorm_kv)
+
+ self.projq_wqb = ProjqWqb(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.projq_wqb)
+
+ self.projo_wkvb = ProjoWKVb(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.projo_wkvb)
+
+ allreduce_algo = UnProjOAllReduceAlgorithm.BF16MMA
+ self.unproj_o_allreduce = UnProjOAllReduce(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ algorithm=allreduce_algo,
+ )
+ self.register_op(self.unproj_o_allreduce)
+
+ self.ll_buf = ll_buf
+
+ self.ki_cache: torch.Tensor | None = None
+ self.kv_cache: torch.Tensor | None = None
+ self.pe_cache: torch.Tensor | None = None
+
+ def init_random_weights(self) -> None:
+ """Initialize random weights for this module."""
+ super().init_random_weights()
+
+ from tilert.models.common import init_func
+
+ for op in [self.projq_wqb, self.projo_wkvb]:
+ padded_total = op.num_local_heads * op.num_devices
+ w = init_func(
+ torch.empty(
+ padded_total * op.wkvb_head_dim, op.wkvb_lora_rank, dtype=torch.float8_e4m3fn
+ )
+ )
+ s = init_func(
+ torch.empty(
+ padded_total * op.wkvb_head_dim // op.model_args.block_size,
+ op.wkvb_lora_rank_qsize,
+ dtype=torch.float32,
+ )
+ )
+ ref_dict = dict(zip(op.ref_weights_alias(), [w, s]))
+ op.init_reference_weights(ref_dict)
+ sharded = op.device_sharding(ref_dict)
+ per_dev = {k: v[op.device_id] for k, v in sharded.items()}
+ op.init_tilert_weights_hmma(per_dev)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Load TileRT weights for this module from state_dict."""
+ self.projq_wqb.is_tilert_weights_init = True
+ self.projo_wkvb.is_tilert_weights_init = True
+
+ super().init_tilert_weights(state_dict)
+
+ for op in [self.projq_wqb, self.projo_wkvb]:
+ op_state_dict = {}
+ for op_key in op.get_tilert_weights_alias():
+ for p, s in zip(self.prefix_seq, self.suffix_seq):
+ original_key = f"{p}{op_key}{s}"
+ if original_key in state_dict:
+ op_state_dict[op_key] = state_dict[original_key]
+ break
+ op.is_tilert_weights_init = False
+ op.init_tilert_weights_hmma(op_state_dict)
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """Return weight tensors."""
+ weights = super().get_weights_list()
+
+ if self.ll_buf is None:
+ max_seq_len = getattr(self.model_args, "num_mtp", 3) + 1
+ topk = self.model_args.index_topk
+ self.ll_buf = torch.zeros(
+ max_seq_len * topk * 2, dtype=torch.int32, device=f"cuda:{self.device_id}"
+ )
+
+ weights.append(self.ll_buf)
+
+ return weights
+
+ def get_cache_vars(self) -> list[torch.Tensor]:
+ """Return [ki_cache, kv_cache, pe_cache] matching DsaCacheVars layout."""
+ cache_seq_len = self.model_args.max_seq_len + self.model_args.kv_cache_pad
+ bs_args = (self.model_args.max_batch_size, cache_seq_len)
+
+ if self.ki_cache is None:
+ ki_dim = self.model_args.index_head_dim
+ self.ki_cache = torch.zeros(
+ *bs_args, ki_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.kv_cache is None:
+ kv_dim = self.model_args.kv_lora_rank
+ self.kv_cache = torch.zeros(
+ *bs_args, kv_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.pe_cache is None:
+ pe_dim = self.model_args.qk_rope_head_dim
+ self.pe_cache = torch.zeros(
+ *bs_args, pe_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ return [*super().get_cache_vars(), self.ki_cache, self.kv_cache, self.pe_cache]
diff --git a/python/models/deepseek_v3_2/modules/mlp.py b/tilert/models/deepseek_v3_2/modules/mlp.py
similarity index 51%
rename from python/models/deepseek_v3_2/modules/mlp.py
rename to tilert/models/deepseek_v3_2/modules/mlp.py
index 1e9a327..217de6a 100644
--- a/python/models/deepseek_v3_2/modules/mlp.py
+++ b/tilert/models/deepseek_v3_2/modules/mlp.py
@@ -1,7 +1,9 @@
from tilert.models.base import SerializableTileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.models.deepseek_v3_2.modules.mla import Mla
-from tilert.models.deepseek_v3_2.ops.down_allreduce import DownAllReduce
+from tilert.models.deepseek_v3_2.modules.mla_v2 import PureMlaV2 as Mla
+from tilert.models.deepseek_v3_2.ops.down_allreduce import (
+ DownAllReduce,
+)
from tilert.models.deepseek_v3_2.ops.rmsnorm_up_gate_silu import (
RMSNormUpGateSiLU,
RMSNormUpGateSiLUAlgorithm,
@@ -11,7 +13,12 @@
class Mlp(SerializableTileRTModule):
"""Implement the MLP operations."""
- def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ ):
super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
self.rmsnorm_mlp_up_gate_silu = RMSNormUpGateSiLU(
@@ -19,9 +26,9 @@ def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
device_id=device_id,
num_devices=num_devices,
)
- if model_args.arch_name == "glm_5":
- self.rmsnorm_mlp_up_gate_silu.algorithm = RMSNormUpGateSiLUAlgorithm.FP16MMA
+ self.rmsnorm_mlp_up_gate_silu.algorithm = RMSNormUpGateSiLUAlgorithm.FP16MMA
self.register_op(self.rmsnorm_mlp_up_gate_silu)
+
self.rmsnorm_mlp_down = DownAllReduce(
model_args=model_args, device_id=device_id, num_devices=num_devices
)
@@ -32,7 +39,15 @@ class MlpBlock(SerializableTileRTModule):
"""Implement the MOE block operations."""
def __init__(
- self, model_args: ModelArgs, device_id: int, num_devices: int, remove_selected: bool = False
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ remove_selected: bool = False,
+ mla_cls: type | None = None,
+ mla_num_devices: int | None = None,
+ mla_kwargs: dict | None = None,
+ mlp: "Mlp | None" = None,
):
super().__init__(
model_args=model_args,
@@ -41,7 +56,19 @@ def __init__(
remove_selected=remove_selected,
)
- self.mla = Mla(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ mla_class = mla_cls or Mla
+ mla_nd = mla_num_devices if mla_num_devices is not None else num_devices
+ self.mla = mla_class(
+ model_args=model_args, device_id=device_id, num_devices=mla_nd, **(mla_kwargs or {})
+ )
self.register_op(self.mla)
- self.mlp = Mlp(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ self.mlp = (
+ mlp
+ if mlp is not None
+ else Mlp(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+ )
self.register_op(self.mlp)
diff --git a/python/models/deepseek_v3_2/modules/moe.py b/tilert/models/deepseek_v3_2/modules/moe.py
similarity index 53%
rename from python/models/deepseek_v3_2/modules/moe.py
rename to tilert/models/deepseek_v3_2/modules/moe.py
index f343e79..4ea2dff 100644
--- a/python/models/deepseek_v3_2/modules/moe.py
+++ b/tilert/models/deepseek_v3_2/modules/moe.py
@@ -1,17 +1,26 @@
+import torch
+
from tilert.models.base import SerializableTileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.models.deepseek_v3_2.modules.mla import Mla
-from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import ExpertDownAllReduce
+from tilert.models.deepseek_v3_2.modules.mla_v2 import PureMlaV2 as Mla
+from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import (
+ ExpertDownAllReduce,
+ ExpertDownAllReduceAlgorithm,
+)
from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import (
ExpertSelectUpGateSiLU,
ExpertSelectUpGateSiLUAlgorithm,
)
-from tilert.models.deepseek_v3_2.ops.rmsnorm_expert_proj import RMSNormExpertProj
+from tilert.models.deepseek_v3_2.ops.rmsnorm_expert_proj import (
+ RMSNormExpertProj,
+)
class Moe(SerializableTileRTModule):
"""Implement the MOE operations."""
+ rmsnorm_expert_proj: RMSNormExpertProj
+
def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
@@ -21,22 +30,38 @@ def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
self.register_op(self.rmsnorm_expert_proj)
self.exp_sel_up_gate_silu = ExpertSelectUpGateSiLU(
- model_args=model_args, device_id=device_id, num_devices=num_devices
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ algorithm=ExpertSelectUpGateSiLUAlgorithm.BF16MMA,
)
- if model_args.arch_name == "glm_5":
- self.exp_sel_up_gate_silu.algorithm = ExpertSelectUpGateSiLUAlgorithm.FP16MMA
self.register_op(self.exp_sel_up_gate_silu)
+
self.expert_down_allreduce = ExpertDownAllReduce(
- model_args=model_args, device_id=device_id, num_devices=num_devices
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ algorithm=ExpertDownAllReduceAlgorithm.BF16MMA,
)
self.register_op(self.expert_down_allreduce)
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return super().get_weights_list()
+
class MoeBlock(SerializableTileRTModule):
"""Implement the MOE block operations."""
def __init__(
- self, model_args: ModelArgs, device_id: int, num_devices: int, remove_selected: bool = False
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ remove_selected: bool = False,
+ mla_cls: type | None = None,
+ mla_num_devices: int | None = None,
+ mla_kwargs: dict | None = None,
+ moe: "Moe | None" = None,
):
super().__init__(
model_args=model_args,
@@ -45,7 +70,15 @@ def __init__(
remove_selected=remove_selected,
)
- self.mla = Mla(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ mla_class = mla_cls or Mla
+ mla_nd = mla_num_devices if mla_num_devices is not None else num_devices
+ self.mla = mla_class(
+ model_args=model_args, device_id=device_id, num_devices=mla_nd, **(mla_kwargs or {})
+ )
self.register_op(self.mla)
- self.moe = Moe(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ self.moe = (
+ moe
+ if moe is not None
+ else Moe(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ )
self.register_op(self.moe)
diff --git a/python/models/deepseek_v3_2/modules/mtp.py b/tilert/models/deepseek_v3_2/modules/mtp.py
similarity index 75%
rename from python/models/deepseek_v3_2/modules/mtp.py
rename to tilert/models/deepseek_v3_2/modules/mtp.py
index fd43e0e..a24101b 100644
--- a/python/models/deepseek_v3_2/modules/mtp.py
+++ b/tilert/models/deepseek_v3_2/modules/mtp.py
@@ -10,7 +10,15 @@
class MTP(SerializableTileRTModule):
"""MTP module."""
- def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ mla_cls: type | None = None,
+ mla_num_devices: int | None = None,
+ mla_kwargs: dict | None = None,
+ ):
super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
self.embed_tokens_weight = None
@@ -23,7 +31,14 @@ def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
suffix=f"_dev_{device_id}",
)
self.register_op(
- MoeBlock(model_args=model_args, device_id=device_id, num_devices=num_devices),
+ MoeBlock(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ mla_cls=mla_cls,
+ mla_num_devices=mla_num_devices,
+ mla_kwargs=mla_kwargs,
+ ),
prefix=f"layer_{mtp_layer_id}_",
suffix=f"_dev_{device_id}",
)
diff --git a/python/models/deepseek_v3_2/modules/mtp_preprocess.py b/tilert/models/deepseek_v3_2/modules/mtp_preprocess.py
similarity index 95%
rename from python/models/deepseek_v3_2/modules/mtp_preprocess.py
rename to tilert/models/deepseek_v3_2/modules/mtp_preprocess.py
index dc094eb..2a8676a 100644
--- a/python/models/deepseek_v3_2/modules/mtp_preprocess.py
+++ b/tilert/models/deepseek_v3_2/modules/mtp_preprocess.py
@@ -22,10 +22,7 @@ def mtp_preprocess_layer(
temp_vars: list[torch.Tensor],
profile_logs: torch.Tensor,
) -> torch.Tensor:
- """MTP preprocess layer op for DeepSeek v3.
-
- Output is in temp_vars[28] (eh_proj) for DSA temp vars layout.
- """
+ """MTP preprocess layer op for DeepSeek v3."""
return torch.ops.tilert.mtp_preprocess_layer(params, temp_vars, profile_logs)
@@ -90,9 +87,6 @@ def convert_to_tilert(self, weights: list[torch.Tensor], device_id: int) -> list
embedding_rmsnorm_gamma = embedding_rmsnorm_gamma.to(device=device, dtype=torch.float32)
hidden_rmsnorm_gamma = hidden_rmsnorm_gamma.to(device=device, dtype=torch.float32)
- # eh_proj: [out, in] = [7168, 14336]; split on dim=1 -> 8 x [7168, 1792]
- # Reshape: [7168, 1792] -> [128, 56, 7, 256] -> transpose(1,2) -> [128, 7, 56, 256]
- # eh_proj_weight_splited = torch.chunk(eh_proj_weight, self.num_devices, dim=1)
eh_proj_weights = (
eh_proj_weight.reshape(
128, self.model_args.dim // 128, self.model_args.dim * 2 // 256 // 8, 256
diff --git a/tilert/models/deepseek_v3_2/ops/__init__.py b/tilert/models/deepseek_v3_2/ops/__init__.py
new file mode 100644
index 0000000..e62ed9d
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/__init__.py
@@ -0,0 +1,160 @@
+"""Core operations for deepseek v3.2."""
+
+from tilert.models.deepseek_v3_2.ops.broadcast_selected_token_ids import (
+ broadcast_selected_token_ids,
+)
+from tilert.models.deepseek_v3_2.ops.down_allreduce import (
+ DownAllReduce,
+ DownAllReduceAlgorithm,
+ down_allreduce,
+)
+from tilert.models.deepseek_v3_2.ops.eh_proj_allreduce import (
+ EHProjAllReduce,
+ EHProjAllReduceAlgorithm,
+ eh_proj_allreduce,
+)
+from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import (
+ ExpertDownAllReduce,
+ ExpertDownAllReduceAlgorithm,
+ expert_down_allreduce,
+)
+from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import (
+ ExpertSelectUpGateSiLU,
+ ExpertSelectUpGateSiLUAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.flash_sparse_mla import (
+ FlashSparseMLACombineAlgorithm,
+ flash_sparse_mla,
+)
+from tilert.models.deepseek_v3_2.ops.layernorm_rope_rotate import (
+ LayerNormRoPERotateAlgorithm,
+ layernorm_rope_rotate,
+)
+from tilert.models.deepseek_v3_2.ops.padded_allreduce_add import (
+ PaddedAllReduceAdd,
+ PaddedAllReduceAddAlgorithm,
+ padded_allreduce_add,
+)
+from tilert.models.deepseek_v3_2.ops.projo_wkvb import ProjoWKVbAlgorithm, projo_wkvb
+from tilert.models.deepseek_v3_2.ops.projq_wqb import ProjqWqbAlgorithm, projq_wqb
+from tilert.models.deepseek_v3_2.ops.projx_wis import ProjxWisAlgorithm, projx_wis
+from tilert.models.deepseek_v3_2.ops.qkv_rope import (
+ QKVRoPE,
+ QKVRoPEAlgorithm,
+ QKVRoPERefWeightsAlias,
+ QKVRoPETilertWeightsAlias,
+ qkv_rope,
+)
+from tilert.models.deepseek_v3_2.ops.receive_selected_token_ids import (
+ receive_selected_token_ids,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_expert_proj import (
+ RMSNormExpertProj,
+ RMSNormExpertProjAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_head_proj import (
+ RMSNormHeadProj,
+ RMSNormHeadProjAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_kv import KVRMSNormAlgorithm, rmsnorm_kv
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqb,
+ RmsnormProjqWqbAlgorithm,
+ RmsnormProjqWqbWeightsConverter,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqi import (
+ RmsnormProjqWqi,
+ RmsnormProjqWqiAlgorithm,
+ RmsnormProjqWqiWeightsConverter,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqakis import (
+ RMSNormProjxWqakis,
+ RMSNormProjxWqakisAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkva import (
+ RMSNormProjxWqkva,
+ RMSNormProjxWqkvaAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant
+from tilert.models.deepseek_v3_2.ops.rmsnorm_up_gate_silu import (
+ RMSNormUpGateSiLU,
+ RMSNormUpGateSiLUAlgorithm,
+)
+from tilert.models.deepseek_v3_2.ops.rotate import (
+ Rotate,
+ RotateAlgorithm,
+ RotateRefWeightsAlias,
+ RotateTilertWeightsAlias,
+ rotate,
+ rotate_activation,
+)
+from tilert.models.deepseek_v3_2.ops.sparse_index import sparse_index, sparse_index_topk
+from tilert.models.deepseek_v3_2.ops.topk import TopK, topk_accurate, topk_approximate
+from tilert.models.deepseek_v3_2.ops.unproj_o_allreduce import (
+ UnProjOAllReduce,
+ UnProjOAllReduceAlgorithm,
+ unproj_o_allreduce,
+)
+
+__all__ = [
+ "down_allreduce",
+ "DownAllReduce",
+ "DownAllReduceAlgorithm",
+ "expert_down_allreduce",
+ "ExpertDownAllReduce",
+ "ExpertDownAllReduceAlgorithm",
+ "rmsnorm_kv",
+ "KVRMSNormAlgorithm",
+ "unproj_o_allreduce",
+ "projo_wkvb",
+ "ProjoWKVbAlgorithm",
+ "projq_wqb",
+ "ProjqWqbAlgorithm",
+ "rotate",
+ "rotate_activation",
+ "Rotate",
+ "RotateAlgorithm",
+ "RotateRefWeightsAlias",
+ "RotateTilertWeightsAlias",
+ "layernorm_rope_rotate",
+ "LayerNormRoPERotateAlgorithm",
+ "TopK",
+ "topk_approximate",
+ "topk_accurate",
+ "sparse_index",
+ "sparse_index_topk",
+ "flash_sparse_mla",
+ "FlashSparseMLACombineAlgorithm",
+ "projx_wis",
+ "ProjxWisAlgorithm",
+ "qkv_rope",
+ "QKVRoPE",
+ "QKVRoPEAlgorithm",
+ "QKVRoPERefWeightsAlias",
+ "QKVRoPETilertWeightsAlias",
+ "eh_proj_allreduce",
+ "EHProjAllReduceAlgorithm",
+ "rmsnorm_quant",
+ "RmsnormProjqWqi",
+ "RmsnormProjqWqiAlgorithm",
+ "RmsnormProjqWqiWeightsConverter",
+ "RMSNormExpertProj",
+ "RMSNormExpertProjAlgorithm",
+ "RMSNormProjxWqakis",
+ "RMSNormProjxWqakisAlgorithm",
+ "RMSNormProjxWqkva",
+ "RMSNormProjxWqkvaAlgorithm",
+ "RMSNormUpGateSiLU",
+ "RMSNormUpGateSiLUAlgorithm",
+ "UnProjOAllReduce",
+ "UnProjOAllReduceAlgorithm",
+ "RMSNormHeadProj",
+ "RMSNormHeadProjAlgorithm",
+ "ExpertSelectUpGateSiLU",
+ "ExpertSelectUpGateSiLUAlgorithm",
+ "PaddedAllReduceAdd",
+ "PaddedAllReduceAddAlgorithm",
+ "padded_allreduce_add",
+ "broadcast_selected_token_ids",
+ "receive_selected_token_ids",
+]
diff --git a/tilert/models/deepseek_v3_2/ops/broadcast_selected_token_ids.py b/tilert/models/deepseek_v3_2/ops/broadcast_selected_token_ids.py
new file mode 100644
index 0000000..f6bf2a8
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/broadcast_selected_token_ids.py
@@ -0,0 +1,36 @@
+"""BroadcastSelectedTokenIds — P2P broadcast of idx_selects from GPU 0 to peers."""
+
+import torch
+
+__all__ = [
+ "broadcast_selected_token_ids",
+]
+
+
+def broadcast_selected_token_ids(
+ idx_selects: torch.Tensor,
+ peer_bufs: torch.Tensor,
+ flag_val: int,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """Broadcast idx_selects [1,S,2048] int32 from GPU 0 to peer GPUs.
+
+ Args:
+ idx_selects: Source tensor [1, S, 2048] int32 on GPU 0.
+ peer_bufs: Device pointer array [N] int64 — each entry is a peer
+ buffer address.
+ flag_val: Synchronization flag value.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.broadcast_selected_token_ids_op(
+ idx_selects,
+ peer_bufs,
+ flag_val,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
diff --git a/python/models/deepseek_v3_2/ops/down_allreduce.py b/tilert/models/deepseek_v3_2/ops/down_allreduce.py
similarity index 82%
rename from python/models/deepseek_v3_2/ops/down_allreduce.py
rename to tilert/models/deepseek_v3_2/ops/down_allreduce.py
index dfb5a81..cd81461 100644
--- a/python/models/deepseek_v3_2/ops/down_allreduce.py
+++ b/tilert/models/deepseek_v3_2/ops/down_allreduce.py
@@ -1,10 +1,8 @@
"""DownAllreduce operation module."""
-from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
-# import torch.nn.functional as F
import torch
from tilert.models.base import TileRTModule
@@ -17,7 +15,6 @@
__all__ = [
"down_allreduce",
- "down_allreduce_glm5",
"DownAllReduceAlgorithm",
"DownAllReduce",
"DownAllReduceTilertWeightsAlias",
@@ -32,6 +29,8 @@ def down_allreduce(
flag: int,
vec_out: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
) -> None:
"""
Fused operation of down and allreduce.
@@ -43,10 +42,9 @@ def down_allreduce(
x_in: Input tensor.
flag: Input flag.
vec_out: Output tensor.
- profile_logs: Profile logs tensor. This is a 1D tensor of shape
- (num_sms,) to store the profile logs of the down_allreduce
- operation, where num_sms is the number of SMs on the
- device.
+ profile_logs: Profile logs tensor (1D).
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
"""
torch.ops.tilert.down_allreduce_op(
vec_in,
@@ -55,41 +53,8 @@ def down_allreduce(
x_in,
flag,
vec_out,
- profile_logs,
- )
-
-
-def down_allreduce_glm5(
- vec_in: torch.Tensor,
- mat_in: torch.Tensor,
- mat_scale: torch.Tensor,
- x_in: torch.Tensor,
- flag: int,
- vec_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Fused operation of down and allreduce.
-
- Args:
- vec_in: Input tensor.
- mat_in: Input tensor.
- mat_scale: Input tensor.
- x_in: Input tensor.
- flag: Input flag.
- vec_out: Output tensor.
- profile_logs: Profile logs tensor. This is a 1D tensor of shape
- (num_sms,) to store the profile logs of the down_allreduce
- operation, where num_sms is the number of SMs on the
- device.
- """
- torch.ops.tilert.down_allreduce_glm5_op(
- vec_in,
- mat_in,
- mat_scale,
- x_in,
- flag,
- vec_out,
+ model_arch,
+ compute_kernel_type,
profile_logs,
)
@@ -121,6 +86,11 @@ def __call__(self) -> list[str]:
class DownAllReduce(TileRTModule):
"""DownAllReduce module"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [DownAllReduceAlgorithm.GENERAL],
+ "glm_5": [DownAllReduceAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -142,7 +112,6 @@ def __init__(
self.moe_inter_dim = self.model_args.moe_inter_dim
self.moe_inter_dim_per_device = self.moe_inter_dim // self.num_devices
self.inter_dim_per_device = self.inter_dim // self.num_devices
- # effective number of experts
self.n_experts: int = self.inter_dim_per_device // self.moe_inter_dim_per_device
self.block_size = self.model_args.block_size
self.dim_scale_dim = self.dim // self.block_size
@@ -150,39 +119,30 @@ def __init__(
self.moe_inter_scale_dim_per_device = self.moe_inter_dim_per_device // self.block_size
self.algorithm = algorithm
- # reference weights
+ if self.arch_name in ("deepseek_v3_2", "glm_5"):
+ self.compute_kernel_type = "bf16"
+ else:
+ raise ValueError(f"Unsupported architecture: {self.arch_name}")
+
+ self.model_arch = self.arch_name
+
self.ref_down: torch.Tensor | None = None
- # tilert weights
self.tilert_weights: torch.Tensor | None = None
self.tilert_scales: torch.Tensor | None = None
- # tilert vars
self.hidden_out: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
self.is_init = False
- # tilert_funcs
- self.down_allreduce_func: Callable | None = None
-
- if self.arch_name == "deepseek_v3_2":
- self.down_allreduce_func = down_allreduce
- elif self.arch_name == "glm_5":
- self.down_allreduce_func = down_allreduce_glm5
- else:
- raise ValueError(f"Unsupported architecture: {self.arch_name}")
-
self.tilert_weights_alias = DownAllReduceTilertWeightsAlias()
- # for device sharding, corresponding to the output of device_sharding
- # and input of tilert_forward
self.tensor_alias: list[str] = [
"down_weights",
"down_scales",
]
- # reference tensor aliases
self.ref_tensor_alias: list[str] = [
"mlp.down_proj.weight",
"mlp.down_proj.weight_scale_inv",
@@ -204,7 +164,7 @@ def get_weights_list(self) -> list[torch.Tensor]:
def device_sharding(
self,
weights_dict: dict[str, torch.Tensor],
- key_prefix: str, # e.g. model.layers.{layer_id}.mlp
+ key_prefix: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Device sharding.
@@ -219,7 +179,6 @@ def device_sharding(
down_proj_scale_key = f"{key_prefix}.down_proj.weight_scale_inv"
down_proj_weight = weights_dict[down_proj_weight_key]
down_proj_scale = weights_dict[down_proj_scale_key]
- # To match the old convertcode
down_proj_weight = down_proj_weight.reshape(
self.dim, self.n_experts, self.num_devices, self.moe_inter_dim_per_device
)
@@ -254,7 +213,7 @@ def device_sharding(
def init_reference_weights(
self,
state_dict: dict[str, torch.Tensor],
- key_prefix: str, # e.g. model.layers.{layer_id}.mlp
+ key_prefix: str,
device_id: int = 0,
) -> None:
"""
@@ -295,7 +254,6 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) ->
batch_size: Batch size.
seq_len: Sequence length.
"""
- # tilert vars
self.hidden_out = torch.zeros(
(batch_size, seq_len, self.dim),
dtype=torch.bfloat16,
@@ -304,8 +262,10 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) ->
self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}")
self.is_init = True
- def init_random_weights(self, device_id: int = 0) -> None:
+ def init_random_weights(self, device_id: int | None = None) -> None:
"""Initialize the random weights."""
+ if device_id is None:
+ device_id = self.device_id
scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
down_weights = torch.randn(
self.dim, self.inter_dim, dtype=torch.bfloat16, device=f"cuda:{device_id}"
@@ -364,9 +324,8 @@ def tilert_forward(
x_in: torch.Tensor,
flag: int,
) -> torch.Tensor:
- assert self.down_allreduce_func is not None
assert self.hidden_out is not None
- self.down_allreduce_func(
+ down_allreduce(
vec_in,
self.tilert_weights,
self.tilert_scales,
@@ -374,6 +333,8 @@ def tilert_forward(
flag,
self.hidden_out,
self.profile_logs,
+ self.model_arch,
+ self.compute_kernel_type,
)
return self.hidden_out
diff --git a/tilert/models/deepseek_v3_2/ops/eh_proj_allreduce.py b/tilert/models/deepseek_v3_2/ops/eh_proj_allreduce.py
new file mode 100644
index 0000000..8a72823
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/eh_proj_allreduce.py
@@ -0,0 +1,295 @@
+"""EHProjAllReduce operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "eh_proj_allreduce",
+ "EHProjAllReduceTilertWeightsAlias",
+]
+
+
+def eh_proj_allreduce(
+ vec_in_enorm: torch.Tensor,
+ vec_in_hnorm: torch.Tensor,
+ w_eh: torch.Tensor,
+ flag: int,
+ vec_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+) -> None:
+ """
+ Fused operation of EHProj and allreduce.
+
+ Args:
+ vec_in_enorm: Input tensor of shape (1, seq_len, 7168).
+ vec_in_hnorm: Input tensor of shape (1, seq_len, 7168).
+ w_eh: Input tensor of shape (7168, 1792) or (128, 7, 56, 256).
+ flag: Input tensor.
+ vec_out: Output tensor of shape (1, seq_len, 7168).
+ profile_logs: Profile logs tensor (1D).
+ model_arch: Model architecture string.
+ """
+ compute_kernel_type = "bf16"
+ torch.ops.tilert.eh_proj_allreduce_op(
+ vec_in_enorm,
+ vec_in_hnorm,
+ w_eh,
+ flag,
+ vec_out,
+ profile_logs,
+ model_arch,
+ compute_kernel_type,
+ torch.empty(0, dtype=torch.int64, device=vec_in_enorm.device),
+ )
+
+
+class EHProjAllReduceAlgorithm(Enum):
+ """EHProjAllReduce algorithm"""
+
+ GENERAL = "general"
+
+
+class EHProjAllReduceWeightsConverter(TilertWeightsConverter):
+ """EHProj weights converter"""
+
+ def convert_to_general(self, weights_list: list[torch.Tensor]) -> tuple[torch.Tensor]:
+ """
+ Convert the weights to general format.
+
+ Args:
+ weights_list: List of weights.
+
+ Returns:
+ Tuple of weights.
+ """
+ args = self.model_args
+ assert args.arch_name == "deepseek_v3_2" or args.arch_name == "glm_5"
+ dim = args.dim
+ num_sms = 128
+ dim_per_sm = dim // num_sms
+ in_dim = dim * 2
+ in_dim_per_device = in_dim // self.num_devices
+ stages = in_dim_per_device // 256
+
+ with torch.inference_mode():
+ (proj_weights,) = weights_list
+ proj_weights = proj_weights.reshape(num_sms, dim_per_sm, stages, 256)
+ proj_weights = proj_weights.transpose(1, 2)
+ return (proj_weights.contiguous(),)
+
+
+@dataclass
+class EHProjAllReduceTilertWeightsAlias:
+ """TileRT weights alias for EHProjAllReduce."""
+
+ eh_proj_weights = "eh_proj_weights"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.eh_proj_weights]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class EHProjAllReduce(TileRTModule):
+ """EHProjAllReduce module"""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [EHProjAllReduceAlgorithm.GENERAL],
+ "glm_5": [EHProjAllReduceAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ algorithm: EHProjAllReduceAlgorithm = EHProjAllReduceAlgorithm.GENERAL,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+
+ self.algorithm = algorithm
+
+ self.ref_proj: torch.Tensor | None = None
+
+ self.tilert_proj: torch.Tensor | None = None
+
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.tilert_weights_alias = EHProjAllReduceTilertWeightsAlias()
+
+ self.tensor_alias: list[str] = [
+ "eh_proj_weights",
+ ]
+
+ self.ref_tensor_alias: list[str] = [
+ "eh_proj.weight",
+ ]
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias.tilert_tensor_alias
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """
+ Get the weights list.
+
+ Returns:
+ List of weights.
+ """
+ return [self.tilert_proj]
+
+ def device_sharding(
+ self,
+ weights_dict: dict[str, torch.Tensor],
+ key_prefix: str | None = None,
+ ) -> tuple[torch.Tensor]:
+ """
+ Device sharding.
+
+ Args:
+ weights_dict: Dictionary of weights.
+ key_prefix: Key prefix.
+ Returns:
+ Tuple of weights.
+ """
+ eh_proj_key = "eh_proj.weight"
+ if key_prefix is not None:
+ eh_proj_key = f"{key_prefix}.eh_proj.weight"
+
+ eh_proj_weight = weights_dict[eh_proj_key]
+ in_dim = eh_proj_weight.shape[1]
+ out_dim = eh_proj_weight.shape[0]
+ in_dim_per_device = in_dim // self.num_devices
+ eh_proj_weight = eh_proj_weight.reshape(out_dim, self.num_devices, in_dim_per_device)
+ eh_proj_weight = eh_proj_weight.transpose(0, 1)
+ return (eh_proj_weight.contiguous(),)
+
+ def init_reference_weights(
+ self,
+ state_dict: dict[str, torch.Tensor],
+ key_prefix: str | None = None,
+ device_id: int = 0,
+ ) -> None:
+ """
+ Initialize the reference weights.
+
+ Args:
+ state_dict: State dictionary.
+ device_id: Device ID.
+ """
+ sharded_list = self.device_sharding(state_dict, key_prefix)
+
+ eh_proj_weight = sharded_list[0][device_id]
+
+ self.ref_proj = eh_proj_weight
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Initialize the tilert weights.
+
+ Args:
+ state_dict: State dictionary.
+ """
+ assert self.algorithm is not None
+ (self.tilert_proj,) = EHProjAllReduceWeightsConverter(
+ self.model_args, self.num_devices
+ ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tensor_alias])
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) -> None:
+ """
+ Initialize the tilert variables.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}")
+ self.is_init = True
+
+ def init_random_weights(self, device_id: int | None = None) -> None:
+ """Initialize the random weights."""
+ if device_id is None:
+ device_id = self.device_id
+ proj_weights = torch.randn(
+ self.dim, self.dim * 2, dtype=torch.bfloat16, device=f"cuda:{device_id}"
+ )
+
+ tensor_list = [
+ proj_weights,
+ ]
+ state_dict = dict(zip(self.ref_tensor_alias, tensor_list))
+
+ self.init_reference_weights(state_dict, None, device_id)
+ sharded_list = self.device_sharding(state_dict, None)
+ sharded_state_dict = {
+ alias: sharded_list[i][device_id] for i, alias in enumerate(self.tensor_alias)
+ }
+ self.init_tilert_weights(sharded_state_dict)
+
+ def golden_forward(
+ self,
+ vec_in_enorm: torch.Tensor,
+ vec_in_hnorm: torch.Tensor,
+ device_id: int = 0,
+ ) -> torch.Tensor:
+ """
+ Forward pass for the down-project module.
+
+ Args:
+ vec_in_enorm: Input vector of shape (1, seq_len, 7168).
+ vec_in_hnorm: Input vector of shape (1, seq_len, 7168).
+
+ Returns:
+ Output tensor.
+ """
+ assert self.ref_proj is not None
+ bsz = vec_in_enorm.shape[0]
+ assert bsz == 1
+
+ vec_in_concat = torch.cat([vec_in_enorm, vec_in_hnorm], dim=-1)
+ dim_per_device = (self.dim * 2) // self.num_devices
+ vec_in_slice = vec_in_concat[
+ ..., dim_per_device * device_id : dim_per_device * device_id + dim_per_device
+ ]
+ return vec_in_slice @ self.ref_proj.T
+
+ def tilert_forward(
+ self,
+ vec_in_enorm: torch.Tensor,
+ vec_in_hnorm: torch.Tensor,
+ flag: int,
+ ) -> torch.Tensor:
+ assert self.hidden_out is not None
+ eh_proj_allreduce(
+ vec_in_enorm,
+ vec_in_hnorm,
+ self.tilert_proj,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.hidden_out
diff --git a/tilert/models/deepseek_v3_2/ops/expert_down_allreduce.py b/tilert/models/deepseek_v3_2/ops/expert_down_allreduce.py
new file mode 100644
index 0000000..f7ff7b7
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/expert_down_allreduce.py
@@ -0,0 +1,500 @@
+"""ExpertDownAllreduce operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "expert_down_allreduce",
+ "ExpertDownAllReduceAlgorithm",
+ "ExpertDownAllReduce",
+ "ExpertDownAllReduceTilertWeightsAlias",
+]
+
+VALID_SEQ_LENS = (1, 2, 4)
+
+
+def expert_down_allreduce(
+ vec_in: torch.Tensor,
+ mat_in: torch.Tensor,
+ mat_scale: torch.Tensor,
+ indices: torch.Tensor,
+ scores: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ vec_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """
+ Fused expert down + allreduce (unified for DSv32 and GLM5).
+
+ Args:
+ vec_in: [1, seq_len, n_experts, 256], bfloat16.
+ mat_in: [n_experts, dim, 256], float8_e4m3fn.
+ mat_scale: [n_experts, 1024, 2], bfloat16 (DSv32) or float32 (GLM5).
+ indices: [1, seq_len, 8], int32.
+ scores: [1, seq_len, 8], float32.
+ x_in: [1, seq_len, dim], bfloat16.
+ flag: User flag.
+ vec_out: [1, seq_len, dim], bfloat16 (output).
+ profile_logs: 1D tensor for profile logs.
+ compute_kernel_type: "bf16".
+ """
+ torch.ops.tilert.expert_down_allreduce_op(
+ vec_in,
+ mat_in,
+ mat_scale,
+ indices,
+ scores,
+ x_in,
+ flag,
+ vec_out,
+ profile_logs,
+ model_arch,
+ compute_kernel_type,
+ )
+
+
+class ExpertDownAllReduceAlgorithm(Enum):
+ """ExpertDownAllReduce algorithm."""
+
+ GENERAL = "general"
+ BF16MMA = "bf16mma"
+
+
+class ExpertDownAllReduceWeightsConverter(TilertWeightsConverter):
+ """ExpertDownAllReduce weights converter."""
+
+ @staticmethod
+ def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ assert mat_in.dtype == torch.float8_e4m3fn
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_qmma_8x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 8 and mat_in.shape[-1] == 32
+ pre_shape = mat_in.shape[:-2]
+ return mat_in.reshape(*pre_shape, 8, 2, 4, 4).transpose(-2, -3).contiguous()
+
+ @staticmethod
+ def _swizzle_bf16mma_full_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a (16, 32) FP8 sub-block for the BF16 MMA kernel."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ assert mat_in.dtype == torch.float8_e4m3fn
+ pre = mat_in.shape[:-2]
+ mat = mat_in.reshape(*pre, 2, 8, 2, 2, 4, 2)
+ n = len(pre)
+ mat = mat.permute(*range(n), 1 + n, 4 + n, 2 + n, 3 + n, 0 + n, 5 + n)
+ return mat.reshape(*pre, 32, 16).contiguous()
+
+ @staticmethod
+ def _swizzle_bf16mma_partial_8x32(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a (8, 32) FP8 partial sub-block for the BF16 MMA kernel."""
+ assert mat_in.shape[-2] == 8 and mat_in.shape[-1] == 32
+ assert mat_in.dtype == torch.float8_e4m3fn
+ pre = mat_in.shape[:-2]
+ mat = mat_in.reshape(*pre, 8, 2, 2, 4, 2)
+ n = len(pre)
+ mat = mat.permute(*range(n), 0 + n, 3 + n, 1 + n, 2 + n, 4 + n)
+ return mat.reshape(*pre, 32, 8).contiguous()
+
+ def convert_to_general(
+ self, weights_list: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert weights to general (tilert) format."""
+ args = self.model_args
+ assert args.arch_name in ("deepseek_v3_2", "glm_5")
+ arch_name = args.arch_name
+ dim = args.dim
+ num_sms = 128
+ dim_per_sm = dim // num_sms
+ dim_scale_dim = dim // args.block_size
+ expert_dim = args.moe_inter_dim // 8
+ k_chunks = expert_dim // 32
+ scale_cols = expert_dim // args.block_size
+
+ with torch.inference_mode():
+ mat_in, scale_in = weights_list
+ exp_num = mat_in.shape[0]
+ mat_in_s = mat_in.reshape(exp_num, num_sms, dim_per_sm, expert_dim)
+ mat_in_0 = (
+ mat_in_s[:, :, :16].reshape(exp_num, num_sms, 16, k_chunks, 32).transpose(2, 3)
+ )
+ mat_in_0 = self._swizzle_qmma_16x32(mat_in_0).reshape(exp_num, 128, -1)
+ mat_in_1 = (
+ mat_in_s[:, :, 16:32].reshape(exp_num, num_sms, 16, k_chunks, 32).transpose(2, 3)
+ )
+ mat_in_1 = self._swizzle_qmma_16x32(mat_in_1).reshape(exp_num, 128, -1)
+ mat_in_2 = (
+ mat_in_s[:, :, 32:48].reshape(exp_num, num_sms, 16, k_chunks, 32).transpose(2, 3)
+ )
+ mat_in_2 = self._swizzle_qmma_16x32(mat_in_2).reshape(exp_num, 128, -1)
+ mats_to_cat = [mat_in_0, mat_in_1, mat_in_2]
+ if arch_name == "deepseek_v3_2":
+ mat_in_3 = (
+ mat_in_s[:, :, 48:56].reshape(exp_num, num_sms, 8, k_chunks, 32).transpose(2, 3)
+ )
+ mat_in_3 = self._swizzle_qmma_8x32(mat_in_3).reshape(exp_num, 128, -1)
+ mats_to_cat.append(mat_in_3)
+ mat_in_swizzled = torch.cat(mats_to_cat, dim=2)
+ mat_in_swizzled = mat_in_swizzled.reshape(exp_num, dim, expert_dim)
+
+ mat_scale_tilert = (
+ scale_in.reshape(exp_num, dim_scale_dim, 1, scale_cols)
+ .repeat(1, 1, 16, 1)
+ .reshape(exp_num, num_sms, -1)
+ )
+ target_cols_per_sm = 1024 * scale_cols // num_sms
+ pad_amount = target_cols_per_sm - mat_scale_tilert.shape[-1]
+ if pad_amount > 0:
+ padding_zeros = torch.zeros(
+ (exp_num, num_sms, pad_amount),
+ dtype=scale_in.dtype,
+ device=scale_in.device,
+ )
+ mat_scale_tilert = torch.cat([mat_scale_tilert, padding_zeros], dim=2)
+ mat_scale_tilert = mat_scale_tilert.reshape(exp_num, 1024, scale_cols)
+ if arch_name == "glm_5":
+ if mat_scale_tilert.dtype != torch.float32:
+ print(
+ "Warning: ExpertDownAllReduceWeightsConverter: "
+ + f"mat_scale_tilert.dtype: {mat_scale_tilert.dtype} "
+ + "is not float32, convert to float32."
+ )
+ mat_scale_tilert = mat_scale_tilert.to(torch.float32)
+ else:
+ mat_scale_tilert = mat_scale_tilert.to(torch.bfloat16)
+ return mat_in_swizzled.contiguous(), mat_scale_tilert.contiguous()
+
+ def convert_to_bf16mma(
+ self, weights_list: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Pack FP8 weights for the BF16 MMA kernel (DSv32 only)."""
+ args = self.model_args
+ assert args.arch_name == "deepseek_v3_2", "BF16 MMA only wired for DSv32."
+ dim = args.dim
+ num_sms = 128
+ dim_per_sm = dim // num_sms
+ expert_dim = args.moe_inter_dim // 8
+ k_chunks = expert_dim // 32
+ scale_cols = expert_dim // args.block_size
+ assert dim_per_sm == 56, "BF16 MMA layout currently assumes dim_per_sm=56 (DSv32)."
+
+ with torch.inference_mode():
+ mat_in, scale_in = weights_list
+ exp_num = mat_in.shape[0]
+ mat_per_cta = mat_in.reshape(exp_num, num_sms, dim_per_sm, expert_dim)
+
+ full_part = mat_per_cta[:, :, :48, :]
+ partial_part = mat_per_cta[:, :, 48:, :]
+
+ full_tiles = full_part.reshape(exp_num, num_sms, 3, 16, k_chunks, 32)
+ full_tiles = full_tiles.transpose(3, 4)
+ full_swizzled = self._swizzle_bf16mma_full_16x32(full_tiles)
+ full_swizzled = full_swizzled.reshape(exp_num, num_sms, 3 * k_chunks * 32 * 16)
+
+ partial_tiles = partial_part.reshape(exp_num, num_sms, 1, 8, k_chunks, 32).transpose(
+ 3, 4
+ )
+ partial_swizzled = self._swizzle_bf16mma_partial_8x32(partial_tiles)
+ partial_swizzled = partial_swizzled.reshape(exp_num, num_sms, k_chunks * 32 * 8)
+
+ mat_swizzled = torch.cat([full_swizzled, partial_swizzled], dim=2)
+ mat_swizzled = mat_swizzled.reshape(exp_num, dim, expert_dim)
+
+ mat_scale_tilert = (
+ scale_in.reshape(exp_num, dim // args.block_size, 1, scale_cols)
+ .repeat(1, 1, 16, 1)
+ .reshape(exp_num, num_sms, -1)
+ )
+ target_cols_per_sm = 1024 * scale_cols // num_sms
+ pad_amount = target_cols_per_sm - mat_scale_tilert.shape[-1]
+ if pad_amount > 0:
+ padding_zeros = torch.zeros(
+ (exp_num, num_sms, pad_amount),
+ dtype=scale_in.dtype,
+ device=scale_in.device,
+ )
+ mat_scale_tilert = torch.cat([mat_scale_tilert, padding_zeros], dim=2)
+ mat_scale_tilert = mat_scale_tilert.reshape(exp_num, 1024, scale_cols)
+ mat_scale_tilert = mat_scale_tilert.to(torch.bfloat16)
+
+ return mat_swizzled.contiguous(), mat_scale_tilert.contiguous()
+
+
+@dataclass
+class ExpertDownAllReduceTilertWeightsAlias:
+ """TileRT weights alias for ExpertDownAllReduce."""
+
+ exp_down_weights = "exp_down_weights"
+ exp_down_scales = "exp_down_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.exp_down_weights, self.exp_down_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ExpertDownAllReduce(TileRTModule):
+ """ExpertDownAllReduce module."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ ExpertDownAllReduceAlgorithm.GENERAL,
+ ExpertDownAllReduceAlgorithm.BF16MMA,
+ ],
+ "glm_5": [ExpertDownAllReduceAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ algorithm: ExpertDownAllReduceAlgorithm = ExpertDownAllReduceAlgorithm.GENERAL,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+ self.n_activated_experts: int = self.model_args.n_activated_experts
+ self.n_routed_experts: int = self.model_args.n_routed_experts
+ self.n_shared_experts: int = self.model_args.n_shared_experts
+ self.moe_inter_dim = self.model_args.moe_inter_dim
+ self.block_size = self.model_args.block_size
+ self.algorithm = algorithm
+
+ self.ref_down: torch.Tensor | None = None
+ self.tilert_weights: torch.Tensor | None = None
+ self.tilert_scales: torch.Tensor | None = None
+ self.hidden_out: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ if self.arch_name in ("deepseek_v3_2", "glm_5"):
+ self.compute_kernel_type = "bf16"
+ if algorithm == ExpertDownAllReduceAlgorithm.BF16MMA:
+ self.compute_kernel_type = "bf16mma"
+ else:
+ raise ValueError(f"Unsupported architecture: {self.arch_name}")
+
+ self.model_arch = self.arch_name
+
+ self.tilert_weights_alias = ExpertDownAllReduceTilertWeightsAlias()
+ self.tensor_alias = ["exp_down_weights", "exp_down_scales"]
+ self.ref_tensor_alias = (
+ ["mlp.shared_experts.down_proj.weight"]
+ + [f"mlp.experts.{i}.down_proj.weight" for i in range(self.n_routed_experts)]
+ + ["mlp.shared_experts.down_proj.weight_scale_inv"]
+ + [f"mlp.experts.{i}.down_proj.weight_scale_inv" for i in range(self.n_routed_experts)]
+ )
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias.tilert_tensor_alias
+
+ def set_algorithm(self, algorithm: Enum) -> None:
+ """Set algorithm and sync compute_kernel_type for BF16MMA dispatch."""
+ super().set_algorithm(algorithm)
+ if algorithm == ExpertDownAllReduceAlgorithm.BF16MMA:
+ self.compute_kernel_type = "bf16mma"
+ elif algorithm == ExpertDownAllReduceAlgorithm.GENERAL:
+ self.compute_kernel_type = "bf16"
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_weights, self.tilert_scales]
+
+ @staticmethod
+ def process_down_weights(
+ key_prefix: str,
+ weights_hf: dict[str, torch.Tensor],
+ num_devices: int,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ down_proj_weight_key = f"{key_prefix}.down_proj.weight"
+ down_proj_scale_key = f"{key_prefix}.down_proj.weight_scale_inv"
+ down_proj_weight = weights_hf[down_proj_weight_key]
+ down_proj_scale = weights_hf[down_proj_scale_key]
+
+ dim = down_proj_weight.shape[-2]
+ dim_scale_dim = down_proj_scale.shape[-2]
+ moe_inter_dim = down_proj_weight.shape[-1]
+ in_scale_dim = down_proj_scale.shape[-1]
+ moe_inter_dim_per_device = moe_inter_dim // num_devices
+ in_scale_dim_per_device = in_scale_dim // num_devices
+
+ down_proj_weight = down_proj_weight.reshape(dim, num_devices, moe_inter_dim_per_device)
+ down_proj_weight = down_proj_weight.transpose(0, 1).reshape(
+ num_devices, 1, dim, moe_inter_dim_per_device
+ )
+ down_proj_scale = down_proj_scale.reshape(
+ dim_scale_dim, num_devices, in_scale_dim_per_device
+ )
+ down_proj_scale = down_proj_scale.transpose(0, 1).reshape(
+ num_devices, 1, dim_scale_dim, in_scale_dim_per_device
+ )
+ return down_proj_weight, down_proj_scale
+
+ def device_sharding(
+ self,
+ weights_dict: dict[str, torch.Tensor],
+ key_prefix: str,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ assert self.n_shared_experts == 1, "Only one shared expert is supported"
+ down_weights_list = []
+ down_scales_list = []
+ exp_prefix = f"{key_prefix}.shared_experts"
+ down_weights, down_scales = self.process_down_weights(
+ exp_prefix, weights_dict, self.num_devices
+ )
+ down_weights_list.append(down_weights)
+ down_scales_list.append(down_scales)
+ for exp_id in range(self.n_routed_experts):
+ exp_prefix = f"{key_prefix}.experts.{exp_id}"
+ down_weights, down_scales = self.process_down_weights(
+ exp_prefix, weights_dict, self.num_devices
+ )
+ down_weights_list.append(down_weights)
+ down_scales_list.append(down_scales)
+ down_weights = torch.cat(down_weights_list, dim=1)
+ down_scales = torch.cat(down_scales_list, dim=1)
+ return down_weights.contiguous(), down_scales.contiguous()
+
+ def init_reference_weights(
+ self,
+ state_dict: dict[str, torch.Tensor],
+ key_prefix: str,
+ device_id: int = 0,
+ ) -> None:
+ sharded_list = self.device_sharding(state_dict, key_prefix)
+ down_weights = sharded_list[0][device_id]
+ down_scales = sharded_list[1][device_id]
+
+ down_list = [
+ weight_dequant(down_weight, down_scale)
+ for down_weight, down_scale in zip(down_weights, down_scales)
+ ]
+ self.ref_down = torch.stack([t.to(torch.bfloat16) for t in down_list], dim=0)
+
+ def get_tilert_weights_alias(self) -> list[str]:
+ """Return the alias list keyed into ``state_dict`` for this op."""
+ return list(self.tilert_weights_alias())
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_weights, self.tilert_scales = ExpertDownAllReduceWeightsConverter(
+ self.model_args, self.num_devices
+ ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tensor_alias])
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) -> None:
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}")
+ self.is_init = True
+
+ def init_random_weights(self, device_id: int | None = None) -> None:
+ if device_id is None:
+ device_id = self.device_id
+ n = self.n_routed_experts + 1
+ dev = f"cuda:{device_id}"
+ down_weights = list(
+ torch.randn(n, self.dim, self.moe_inter_dim, dtype=torch.bfloat16, device=dev)
+ .to(torch.float8_e4m3fn)
+ .unbind(0)
+ )
+ dim_scale_dim = self.dim // self.block_size
+ moe_inter_dim_scale_dim = self.moe_inter_dim // self.block_size
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+ down_scales = list(
+ torch.randn(
+ n, dim_scale_dim, moe_inter_dim_scale_dim, dtype=scale_dtype, device=dev
+ ).unbind(0)
+ )
+ state_dict = dict(
+ zip(
+ self.ref_tensor_alias,
+ [*down_weights, *down_scales],
+ )
+ )
+ self.init_reference_weights(state_dict, "mlp", device_id)
+ sharded_list = self.device_sharding(state_dict, "mlp")
+ sharded_state_dict = {
+ alias: sharded_list[i][device_id] for i, alias in enumerate(self.tensor_alias)
+ }
+ self.init_tilert_weights(sharded_state_dict)
+
+ def golden_forward(
+ self,
+ vec_in: torch.Tensor,
+ indices: torch.Tensor,
+ scores: torch.Tensor,
+ ) -> torch.Tensor:
+ assert self.ref_down is not None
+ assert vec_in.dim() == 4 and vec_in.size(0) == 1
+ seq_len = vec_in.shape[1]
+ hidden_out_list = []
+ for s in range(seq_len):
+ hidden_out_w2_list = []
+ hidden_out_w2_shared = vec_in[0, s, 0].float() @ self.ref_down[0].float().T
+ hidden_out_w2_list.append(hidden_out_w2_shared)
+ ref_down_sel = self.ref_down[1:][indices[0, s]]
+ for i in range(self.n_activated_experts):
+ hidden_out_w2_sel = vec_in[0, s, i + 1].float() @ ref_down_sel[i].float().T
+ hidden_out_w2_list.append(hidden_out_w2_sel * scores[0, s, i])
+ hidden_out_w2 = torch.stack(hidden_out_w2_list, dim=0).to(torch.bfloat16)
+ hidden_out_w2 = torch.sum(hidden_out_w2, dim=0)
+
+ hidden_out_list.append(hidden_out_w2)
+ hidden_out = torch.stack(hidden_out_list, dim=0)
+ return hidden_out[None, ...]
+
+ def tilert_forward(
+ self,
+ vec_in: torch.Tensor,
+ indices: torch.Tensor,
+ scores: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ ) -> torch.Tensor:
+ assert self.hidden_out is not None
+ expert_down_allreduce(
+ vec_in,
+ self.tilert_weights,
+ self.tilert_scales,
+ indices,
+ scores,
+ x_in,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ self.model_arch,
+ self.compute_kernel_type,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ x_in: torch.Tensor,
+ indices: torch.Tensor,
+ scores: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(x_in, indices, scores)
diff --git a/tilert/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py b/tilert/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py
new file mode 100644
index 0000000..3e663bf
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py
@@ -0,0 +1,713 @@
+"""ExpertSelectUpGateSiLU operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+import torch.nn.functional as F
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "ExpertSelectUpGateSiLUAlgorithm",
+ "ExpertSelectUpGateSiLU",
+ "ExpertSelectUpGateSiLURefWeightsAlias",
+ "ExpertSelectUpGateSiLUTilertWeightsAlias",
+ "expert_select_up_gate_silu",
+]
+
+
+def expert_select_up_gate_silu(
+ hidden_in: torch.Tensor,
+ scores_in: torch.Tensor,
+ bias_in: torch.Tensor,
+ experts_weights_in: torch.Tensor,
+ hidden_out: torch.Tensor,
+ expert_probs_out: torch.Tensor,
+ expert_indices_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ algorithm: str = "fp8mma",
+ *,
+ model_arch: str,
+) -> None:
+ """Expert SelectUpGateSiLU operation."""
+ torch.ops.tilert.expert_select_up_gate_silu_op(
+ hidden_in,
+ scores_in,
+ bias_in,
+ experts_weights_in,
+ hidden_out,
+ expert_probs_out,
+ expert_indices_out,
+ profile_logs,
+ model_arch,
+ algorithm,
+ )
+
+
+@dataclass
+class ExpertSelectUpGateSiLURefWeightsAlias:
+ """Reference weights alias for ExpertSelectUpGateSiLU."""
+
+ key_prefix: str = "mlp"
+ n_routed_experts: int = 256
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ n = self.n_routed_experts
+ return (
+ [f"{self.key_prefix}.gate.e_score_correction_bias"]
+ + [f"{self.key_prefix}.shared_experts.gate_proj.weight"]
+ + [f"{self.key_prefix}.experts.{i}.gate_proj.weight" for i in range(n)]
+ + [f"{self.key_prefix}.shared_experts.up_proj.weight"]
+ + [f"{self.key_prefix}.experts.{i}.up_proj.weight" for i in range(n)]
+ + [f"{self.key_prefix}.shared_experts.gate_proj.weight_scale_inv"]
+ + [f"{self.key_prefix}.experts.{i}.gate_proj.weight_scale_inv" for i in range(n)]
+ + [f"{self.key_prefix}.shared_experts.up_proj.weight_scale_inv"]
+ + [f"{self.key_prefix}.experts.{i}.up_proj.weight_scale_inv" for i in range(n)]
+ )
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class ExpertSelectUpGateSiLUTilertWeightsAlias:
+ """TileRT weights alias for ExpertSelectUpGateSiLU."""
+
+ exp_bias = "exp_bias"
+ exp_gate_weights = "exp_gate_weights"
+ exp_gate_scales = "exp_gate_scales"
+ exp_up_weights = "exp_up_weights"
+ exp_up_scales = "exp_up_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.exp_bias,
+ self.exp_gate_weights,
+ self.exp_gate_scales,
+ self.exp_up_weights,
+ self.exp_up_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ExpertSelectUpGateSiLUAlgorithm(Enum):
+ """ExpertSelectUpGateSiLU algorithm"""
+
+ FP8MMA = "fp8mma"
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+class ExpertSelectUpGateSiLUWeightsConverter(TilertWeightsConverter):
+ """ExpertSelectUpGateSiLU weights converter"""
+
+ @staticmethod
+ def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ assert mat_in.dtype == torch.float8_e4m3fn
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def tilert_to_tilert_144sm(
+ mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str | None = None
+ ) -> torch.Tensor:
+ """
+ Convert tilert weights and scales to tilert_144sm input format.
+
+ Args:
+ mat_in: tilert weights
+ mat_scale_in: tilert scales
+ mma_type: MMA type, None,"16x32" or "16x16"
+ Returns:
+ tilert_144sm weights and scales
+ """
+ exp_num = mat_in.shape[0]
+ assert mat_in.shape == (exp_num, 512, 7168)
+ assert mat_scale_in.shape == (exp_num, 4, 64)
+ weights_trt = mat_in.reshape(exp_num, 128, 4, 7168)
+ weights_w1 = weights_trt[:, :, :2].reshape(exp_num, 256, 7168)
+ weights_w3 = weights_trt[:, :, 2:].reshape(exp_num, 256, 7168)
+ weights_w1 = weights_w1.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3)
+ if mma_type == "16x32":
+ weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4)
+ weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w1)
+ weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4)
+ weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w3)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024)
+ elif mma_type == "16x16":
+ weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4)
+ weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w1)
+ weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4)
+ weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w3)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024)
+
+ weights = torch.cat([weights_w1, weights_w3], dim=3)
+ assert weights.shape == (exp_num, 16, 7, 32, 1024)
+ weights = weights.reshape(exp_num, 16, 7, 32 * 1024)
+
+ scales_unswizzled = torch.zeros(exp_num, 4, 56)
+ for i in range(64):
+ if ((i % 8) * 8 + i // 8) < 56:
+ scales_unswizzled[..., ((i % 8) * 8 + i // 8)] = mat_scale_in[..., i]
+ scales_unswizzled = scales_unswizzled.reshape(exp_num, 2, 2, 56)
+
+ scales_w1 = scales_unswizzled[:, :, :1].repeat(1, 1, 8, 1).reshape(exp_num, 16, 1, 7, 8)
+ scales_w1 = scales_w1.transpose(2, 3)
+ scales_w3 = scales_unswizzled[:, :, 1:].repeat(1, 1, 8, 1).reshape(exp_num, 16, 1, 7, 8)
+ scales_w3 = scales_w3.transpose(2, 3)
+ scales = torch.cat([scales_w1, scales_w3], dim=3)
+ assert scales.shape == (exp_num, 16, 7, 2, 8)
+ scales = (
+ scales.reshape(exp_num, 16, 7, 2 * 8).to(torch.bfloat16).view(dtype=torch.float8_e4m3fn)
+ )
+ weights_and_scales = torch.zeros(
+ exp_num, 16, 7, 32 * 1024 + 128, dtype=torch.float8_e4m3fn, device=mat_in.device
+ )
+ weights_and_scales[:, :, :, : 32 * 1024].copy_(weights)
+ weights_and_scales[:, :, :, 32 * 1024 : 32 * 1024 + 32].copy_(scales)
+ return weights_and_scales
+
+ @staticmethod
+ def tilert_to_tilert_144sm_mma(
+ mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str = "16x32"
+ ) -> torch.Tensor:
+ """
+ Convert tilert weights and scales to tilert_144sm_mma input format.
+
+ Args:
+ mat_in: tilert weights
+ mat_scale_in: tilert scales
+ Returns:
+ tilert_144sm weights and scales
+ """
+ return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm(
+ mat_in, mat_scale_in, mma_type
+ )
+
+ def convert_to_mma(
+ self, weights_list: list[torch.Tensor], algorithm: str = "fp8mma"
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert the weights to mma format."""
+ args = self.model_args
+ dim = args.dim
+ pages = dim // 1024
+ dim_scale_dim = dim // args.block_size
+ with torch.inference_mode():
+ bias_or_gamma, weights_w1, scales_w1, weights_w3, scales_w3 = weights_list
+ exp_num = weights_w1.shape[0]
+ moe_rows = weights_w1.shape[1]
+ n_row_groups = moe_rows // 16
+ scale_m_dim = moe_rows // args.block_size
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, 16, pages, 1024).transpose(2, 3)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, 16, pages, 1024).transpose(2, 3)
+ if algorithm == "fp8mma":
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 32, 32).transpose(
+ 3, 4
+ )
+ weights_w1 = self._swizzle_qmma_16x32(weights_w1)
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 32, 32).transpose(
+ 3, 4
+ )
+ weights_w3 = self._swizzle_qmma_16x32(weights_w3)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 1024)
+ elif algorithm == "fp16mma":
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 64, 16).transpose(
+ 3, 4
+ )
+ weights_w1 = self._swizzle_mma_16x16(weights_w1)
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 64, 16).transpose(
+ 3, 4
+ )
+ weights_w3 = self._swizzle_mma_16x16(weights_w3)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 1024)
+ else:
+ raise ValueError(f"Unsupported algorithm: {algorithm}")
+ weights: torch.Tensor = torch.cat([weights_w1, weights_w3], dim=3)
+ assert weights.shape == (exp_num, n_row_groups, pages, 32, 1024)
+ weights = weights.reshape(exp_num, n_row_groups, pages, 32 * 1024)
+
+ scales_per_page = 1024 // args.block_size
+ repeat_factor = n_row_groups // scale_m_dim
+ scales_w1 = (
+ scales_w1.reshape(exp_num, scale_m_dim, 1, dim_scale_dim)
+ .repeat(1, 1, repeat_factor, 1)
+ .reshape(exp_num, n_row_groups, 1, pages, scales_per_page)
+ )
+ scales_w1 = scales_w1.transpose(2, 3)
+ scales_w3 = (
+ scales_w3.reshape(exp_num, scale_m_dim, 1, dim_scale_dim)
+ .repeat(1, 1, repeat_factor, 1)
+ .reshape(exp_num, n_row_groups, 1, pages, scales_per_page)
+ )
+ scales_w3 = scales_w3.transpose(2, 3)
+ scales = torch.cat([scales_w1, scales_w3], dim=3)
+ assert scales.shape == (exp_num, n_row_groups, pages, 2, scales_per_page)
+
+ if self.model_args.arch_name == "glm_5":
+ if scales.dtype != torch.float32:
+ print(
+ "Warning: ExpertSelectUpGateSiLUWeightsConverter: "
+ + f"scales.dtype: {scales.dtype} "
+ + "is not float32, convert to float32."
+ )
+ scales = scales.to(torch.float32)
+ else:
+ scales = scales.to(torch.bfloat16)
+
+ scales = scales.reshape(exp_num, n_row_groups, pages, 2 * scales_per_page).view(
+ dtype=torch.float8_e4m3fn
+ )
+
+ weights_and_scales = torch.zeros(
+ exp_num,
+ n_row_groups,
+ pages,
+ 32 * 1024 + 128,
+ dtype=torch.float8_e4m3fn,
+ device=weights_w1.device,
+ )
+ weights_and_scales[:, :, :, : 32 * 1024].copy_(weights)
+ weights_and_scales[:, :, :, 32 * 1024 : 32 * 1024 + scales.shape[-1]].copy_(scales)
+
+ return bias_or_gamma.float(), weights_and_scales.contiguous()
+
+ def convert_to_fp8mma(
+ self, weights_list: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert the weights to fp8mma format.
+
+ Args:
+ weights: List of weights.
+
+ Returns:
+ Tuple of weights.
+ """
+ return self.convert_to_mma(weights_list, "fp8mma")
+
+ def convert_to_fp16mma(
+ self, weights_list: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert the weights to fp16mma format.
+
+ Args:
+ weights: List of weights.
+
+ Returns:
+ Tuple of weights.
+ """
+ return self.convert_to_mma(weights_list, "fp16mma")
+
+ def convert_to_bf16mma(
+ self, weights_list: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert the weights to bf16mma format."""
+ return self.convert_to_mma(weights_list, "fp16mma")
+
+
+class ExpertSelectUpGateSiLU(TileRTModule):
+ """ExpertSelectUpGateSiLU module"""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ ExpertSelectUpGateSiLUAlgorithm.FP8MMA,
+ ExpertSelectUpGateSiLUAlgorithm.FP16MMA,
+ ExpertSelectUpGateSiLUAlgorithm.BF16MMA,
+ ],
+ "glm_5": [
+ ExpertSelectUpGateSiLUAlgorithm.FP8MMA,
+ ExpertSelectUpGateSiLUAlgorithm.FP16MMA,
+ ],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: ExpertSelectUpGateSiLURefWeightsAlias | None = None,
+ tilert_weights_alias: ExpertSelectUpGateSiLUTilertWeightsAlias | None = None,
+ algorithm: ExpertSelectUpGateSiLUAlgorithm = ExpertSelectUpGateSiLUAlgorithm.FP8MMA,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+
+ self.n_activated_experts = self.model_args.n_activated_experts
+ self.n_routed_experts = self.model_args.n_routed_experts
+ self.n_shared_experts = self.model_args.n_shared_experts
+ self.moe_inter_dim = self.model_args.moe_inter_dim
+ self.n_expert_groups = self.model_args.n_expert_groups
+ self.n_limited_groups = self.model_args.n_limited_groups
+ self.route_scale = self.model_args.route_scale
+ self.block_size = self.model_args.block_size
+ self.algorithm = algorithm
+
+ self.tilert_weights_alias = (
+ tilert_weights_alias
+ if tilert_weights_alias is not None
+ else ExpertSelectUpGateSiLUTilertWeightsAlias()
+ )
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else ExpertSelectUpGateSiLURefWeightsAlias(
+ key_prefix="mlp", n_routed_experts=self.n_routed_experts
+ )
+ )
+
+ self.ref_bias: torch.Tensor | None = None
+ self.ref_gate: torch.Tensor | None = None
+ self.ref_up: torch.Tensor | None = None
+
+ self.tilert_bias: torch.Tensor | None = None
+ self.tilert_weights: torch.Tensor | None = None
+ self.tilert_scales = (
+ torch.zeros(1, dtype=torch.bfloat16, device=torch.device("cuda"))
+ if torch.cuda.is_available()
+ else None
+ )
+
+ self.hidden_out: torch.Tensor | None = None
+ self.expert_probs: torch.Tensor | None = None
+ self.expert_indices: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self._tensor_alias = self.tilert_weights_alias()
+ self._tilert_tensor_alias = [
+ self.tilert_weights_alias.exp_bias,
+ "exp_upgate_weights",
+ "exp_upgate_scales",
+ ]
+
+ @property
+ def tensor_alias(self) -> list[str]:
+ return self._tensor_alias
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ """Output weight names for get_weights_list (backward compat)."""
+ return self._tilert_tensor_alias
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """
+ Get the weights list.
+
+ Returns:
+ List of weights.
+ """
+ return [self.tilert_bias, self.tilert_weights, self.tilert_scales]
+
+ @staticmethod
+ def process_gate_up_weights(
+ key_prefix: str,
+ weights_hf: dict[str, torch.Tensor],
+ num_devices: int,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ gate_proj_weight_key = f"{key_prefix}.gate_proj.weight"
+ gate_proj_scale_key = f"{key_prefix}.gate_proj.weight_scale_inv"
+ up_proj_weight_key = f"{key_prefix}.up_proj.weight"
+ up_proj_scale_key = f"{key_prefix}.up_proj.weight_scale_inv"
+
+ gate_proj_weight = weights_hf[gate_proj_weight_key]
+ gate_proj_scale = weights_hf[gate_proj_scale_key]
+ up_proj_weight = weights_hf[up_proj_weight_key]
+ up_proj_scale = weights_hf[up_proj_scale_key]
+ dim = gate_proj_weight.shape[-1]
+ in_dim = gate_proj_weight.shape[-2]
+ scale_dim = gate_proj_scale.shape[-1]
+ in_scale_dim = gate_proj_scale.shape[-2]
+ in_dim_per_device = in_dim // num_devices
+ in_scale_dim_per_device = in_scale_dim // num_devices
+ gate_proj_weight = gate_proj_weight.reshape(num_devices, 1, in_dim_per_device, dim)
+ gate_proj_scale = gate_proj_scale.reshape(
+ num_devices, 1, in_scale_dim_per_device, scale_dim
+ )
+ up_proj_weight = up_proj_weight.reshape(num_devices, 1, in_dim_per_device, dim)
+ up_proj_scale = up_proj_scale.reshape(num_devices, 1, in_scale_dim_per_device, scale_dim)
+ return gate_proj_weight, gate_proj_scale, up_proj_weight, up_proj_scale
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: ref state dict -> tilert sharded tensors (num_devices, ...).
+
+ Args:
+ weights_map: State dict keyed by ref_weights_alias().
+
+ Returns:
+ Dict keyed by tilert_weights_alias() with (num_devices, ...) tensors.
+ """
+ ref_alias = self.ref_weights_alias
+ key_prefix = ref_alias.key_prefix
+
+ bias_key = f"{key_prefix}.gate.e_score_correction_bias"
+ bias = weights_map[bias_key]
+ bias = bias[None, :].repeat(self.num_devices, 1)
+
+ gate_weights_list = []
+ gate_scales_list = []
+ up_weights_list = []
+ up_scales_list = []
+ assert self.n_shared_experts == 1, "Only one shared expert is supported"
+ exp_prefix = f"{key_prefix}.shared_experts"
+ gate_weights, gate_scales, up_weights, up_scales = self.process_gate_up_weights(
+ exp_prefix, weights_map, self.num_devices
+ )
+ gate_weights_list.append(gate_weights)
+ gate_scales_list.append(gate_scales)
+ up_weights_list.append(up_weights)
+ up_scales_list.append(up_scales)
+
+ for exp_id in range(self.n_routed_experts):
+ exp_prefix = f"{key_prefix}.experts.{exp_id}"
+ gate_weights, gate_scales, up_weights, up_scales = self.process_gate_up_weights(
+ exp_prefix, weights_map, self.num_devices
+ )
+ gate_weights_list.append(gate_weights)
+ gate_scales_list.append(gate_scales)
+ up_weights_list.append(up_weights)
+ up_scales_list.append(up_scales)
+
+ gate_weights = torch.cat(gate_weights_list, dim=1)
+ gate_scales = torch.cat(gate_scales_list, dim=1)
+ up_weights = torch.cat(up_weights_list, dim=1)
+ up_scales = torch.cat(up_scales_list, dim=1)
+ tilert_alias = self.tilert_weights_alias
+ return {
+ tilert_alias.exp_bias: bias,
+ tilert_alias.exp_gate_weights: gate_weights,
+ tilert_alias.exp_gate_scales: gate_scales,
+ tilert_alias.exp_up_weights: up_weights,
+ tilert_alias.exp_up_scales: up_scales,
+ }
+
+ def init_reference_weights(
+ self,
+ state_dict: dict[str, torch.Tensor],
+ device_id: int | None = None,
+ ) -> None:
+ """
+ Initialize the reference weights.
+
+ Args:
+ state_dict: State dict keyed by ref_weights_alias().
+ device_id: Device ID; defaults to self.device_id.
+ """
+ did = self.device_id if device_id is None else device_id
+ sharded = self.device_sharding(state_dict)
+
+ tilert_alias = self.tilert_weights_alias
+ bias = sharded[tilert_alias.exp_bias][did]
+ gate_weights = sharded[tilert_alias.exp_gate_weights][did]
+ gate_scales = sharded[tilert_alias.exp_gate_scales][did]
+ up_weights = sharded[tilert_alias.exp_up_weights][did]
+ up_scales = sharded[tilert_alias.exp_up_scales][did]
+
+ self.ref_bias = bias
+ ref_gate_list = [
+ weight_dequant(gate_weights[i], gate_scales[i]) for i in range(gate_weights.shape[0])
+ ]
+ ref_up_list = [
+ weight_dequant(up_weights[i], up_scales[i]) for i in range(up_weights.shape[0])
+ ]
+ self.ref_gate = torch.stack([t.to(torch.bfloat16) for t in ref_gate_list], dim=0)
+ self.ref_up = torch.stack([t.to(torch.bfloat16) for t in ref_up_list], dim=0)
+
+ def get_tilert_weights_alias(self) -> list[str]:
+ """Return the alias list keyed into ``state_dict`` for this op."""
+ return list(self.tilert_weights_alias())
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize the tilert weights."""
+ assert self.algorithm is not None, "Algorithm is not set"
+ weights_list = [state_dict[alias] for alias in self.tilert_weights_alias()]
+ converter = ExpertSelectUpGateSiLUWeightsConverter(self.model_args, self.num_devices)
+ self.tilert_bias, self.tilert_weights = converter.dispatch(self.algorithm, weights_list)
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, device: str = "cuda") -> None:
+ """
+ Initialize the tilert variables.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (
+ batch_size,
+ seq_len,
+ self.n_activated_experts + self.n_shared_experts,
+ self.moe_inter_dim // self.num_devices,
+ ),
+ dtype=torch.bfloat16,
+ device=device,
+ )
+ self.expert_probs = torch.zeros(
+ (batch_size, seq_len, self.n_activated_experts),
+ dtype=torch.float32,
+ device=device,
+ )
+ self.expert_indices = torch.zeros(
+ (batch_size, seq_len, self.n_activated_experts),
+ dtype=torch.int32,
+ device=device,
+ )
+
+ self.profile_logs = get_profile_log_tensor(device=device)
+ self.is_init = True
+
+ def init_random_weights(self, device: str = "cuda") -> None:
+ """
+ Initialize the random weights.
+
+ Returns:
+ None
+ """
+ n = self.n_routed_experts + 1
+ bias = torch.randn(self.n_routed_experts, dtype=torch.float32, device=device)
+ gate_weights = list(
+ torch.randn(n, self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device)
+ .to(torch.float8_e4m3fn)
+ .unbind(0)
+ )
+ up_weights = list(
+ torch.randn(n, self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device)
+ .to(torch.float8_e4m3fn)
+ .unbind(0)
+ )
+ moe_inter_dim_scale_dim = self.moe_inter_dim // self.block_size
+ dim_scale_dim = self.dim // self.block_size
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+ gate_scales = list(
+ torch.randn(
+ n, moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device
+ ).unbind(0)
+ )
+ up_scales = list(
+ torch.randn(
+ n, moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device
+ ).unbind(0)
+ )
+ tensor_list = [
+ bias,
+ *gate_weights,
+ *up_weights,
+ *gate_scales,
+ *up_scales,
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ per_device_state = {k: v[self.device_id] for k, v in sharded.items()}
+ self.init_tilert_weights(per_device_state)
+
+ def _ref_expert_select_glm5(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ scores = scores.sigmoid()
+ original_scores = scores
+ if self.ref_bias is not None:
+ scores = scores + self.ref_bias
+ indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1]
+ indices = indices.view(*original_scores.shape[:-1], self.n_activated_experts)
+ weights = original_scores.gather(-1, indices)
+ weights /= weights.sum(dim=-1, keepdim=True)
+ weights *= self.route_scale
+ return weights, indices
+
+ def golden_forward(
+ self,
+ x_in: torch.Tensor,
+ scores: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ assert self.ref_gate is not None
+ assert self.ref_up is not None
+ bsz = x_in.shape[0]
+ seq_len = x_in.shape[1]
+ assert bsz == 1
+ if self.arch_name == "deepseek_v3_2":
+ weights, indices = self._ref_expert_select_ds(scores)
+ elif self.arch_name == "glm_5":
+ weights, indices = self._ref_expert_select_glm5(scores)
+ else:
+ raise ValueError(f"Unsupported architecture: {self.arch_name}")
+ hidden_out_list = []
+ for s in range(seq_len):
+ hidden_out_w1_list = []
+ hidden_out_w3_list = []
+ hidden_out_w1_shared = x_in[0, s].float() @ self.ref_gate[0].float().T
+ hidden_out_w3_shared = x_in[0, s].float() @ self.ref_up[0].float().T
+ hidden_out_w1_list.append(hidden_out_w1_shared)
+ hidden_out_w3_list.append(hidden_out_w3_shared)
+ ref_gate_sel = self.ref_gate[1:][indices[0, s]]
+ ref_up_sel = self.ref_up[1:][indices[0, s]]
+ for i in range(self.n_activated_experts):
+ hidden_out_w1_sel = x_in[0, s].float() @ ref_gate_sel[i].float().T
+ hidden_out_w3_sel = x_in[0, s].float() @ ref_up_sel[i].float().T
+ hidden_out_w1_list.append(hidden_out_w1_sel)
+ hidden_out_w3_list.append(hidden_out_w3_sel)
+ hidden_out_w1 = torch.stack(hidden_out_w1_list, dim=0)
+ hidden_out_w3 = torch.stack(hidden_out_w3_list, dim=0)
+ hidden_out = F.silu(hidden_out_w1.float()) * hidden_out_w3.float()
+ hidden_out = hidden_out.to(torch.bfloat16)
+ hidden_out_list.append(hidden_out)
+ hidden_out = torch.stack(hidden_out_list, dim=0)
+ hidden_out = hidden_out[None, ...]
+ return hidden_out, weights, indices
+
+ def tilert_forward(
+ self,
+ x_in: torch.Tensor,
+ scores: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run the kernel."""
+ assert self.algorithm is not None, "Algorithm is not set"
+ expert_select_up_gate_silu(
+ x_in,
+ scores,
+ self.tilert_bias,
+ self.tilert_weights,
+ self.hidden_out,
+ self.expert_probs,
+ self.expert_indices,
+ self.profile_logs,
+ self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.hidden_out, self.expert_probs, self.expert_indices
diff --git a/python/models/deepseek_v3_2/ops/flash_sparse_mla.py b/tilert/models/deepseek_v3_2/ops/flash_sparse_mla.py
similarity index 88%
rename from python/models/deepseek_v3_2/ops/flash_sparse_mla.py
rename to tilert/models/deepseek_v3_2/ops/flash_sparse_mla.py
index deebddc..4513b5f 100644
--- a/python/models/deepseek_v3_2/ops/flash_sparse_mla.py
+++ b/tilert/models/deepseek_v3_2/ops/flash_sparse_mla.py
@@ -1,12 +1,12 @@
"""Flash Sparse MLA operation module."""
import math
+from enum import Enum
import torch
from tilert.models.base import TileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -25,6 +25,9 @@ def flash_sparse_mla(
output: torch.Tensor,
profile_logs: torch.Tensor,
split_size: int = 64,
+ compute_kernel_type: str = "bf16mma",
+ *,
+ model_arch: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Flash Sparse MLA operation for GLM5.
@@ -60,7 +63,7 @@ def flash_sparse_mla(
acc_type = torch.float32
dim = key_value.shape[-1]
- max_num_splits = 32 # topk / split_size = 2048/64
+ max_num_splits = 32
lse = torch.empty((batch, seqlen, heads), device=device, dtype=acc_type)
lse_acc = torch.empty((batch, seqlen, heads, max_num_splits), device=device, dtype=acc_type)
@@ -68,44 +71,42 @@ def flash_sparse_mla(
batch, seqlen, heads, max_num_splits, dim, device=device, dtype=acc_type
)
- if heads == 16:
- torch.ops.tilert.flash_sparse_mla_op(
- query,
- query_pe,
- key_value,
- key_pe,
- indices,
- cur_pos,
- output,
- output_acc,
- lse,
- lse_acc,
- profile_logs,
- split_size,
- )
- elif heads == 8:
- torch.ops.tilert.flash_sparse_mla_glm5_op(
- query,
- query_pe,
- key_value,
- key_pe,
- indices,
- cur_pos,
- output,
- output_acc,
- lse,
- lse_acc,
- profile_logs,
- split_size,
- )
- else:
+ if heads not in (8, 10, 16, 20):
raise ValueError(f"Unsupported heads: {heads}")
+ torch.ops.tilert.flash_sparse_mla_op(
+ query,
+ query_pe,
+ key_value,
+ key_pe,
+ indices,
+ cur_pos,
+ output,
+ output_acc,
+ lse,
+ lse_acc,
+ split_size,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=query.device),
+ )
return lse, lse_acc, output_acc
+class FlashSparseMLACombineAlgorithm(Enum):
+ """FlashSparseMLACombine algorithm."""
+
+ BF16MMA = "bf16mma"
+
+
class FlashSparseMLACombine(TileRTModule):
"""Flash Sparse MLA combine module; no weights, uses model_args for scale and config."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [FlashSparseMLACombineAlgorithm.BF16MMA],
+ "glm_5": [FlashSparseMLACombineAlgorithm.BF16MMA],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -239,13 +240,8 @@ def tilert_forward(
cur_pos,
output,
self.profile_logs,
+ model_arch=self.model_args.arch_name,
)
- if self.flag_enable_profiling_log:
- # TODO: bug fix for this
- torch.cuda.synchronize()
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return output
def to_tilert_weights(self) -> None:
diff --git a/python/models/deepseek_v3_2/ops/layernorm_rope_rotate.py b/tilert/models/deepseek_v3_2/ops/layernorm_rope_rotate.py
similarity index 92%
rename from python/models/deepseek_v3_2/ops/layernorm_rope_rotate.py
rename to tilert/models/deepseek_v3_2/ops/layernorm_rope_rotate.py
index b1bd0a7..ae6e6c1 100644
--- a/python/models/deepseek_v3_2/ops/layernorm_rope_rotate.py
+++ b/tilert/models/deepseek_v3_2/ops/layernorm_rope_rotate.py
@@ -1,6 +1,7 @@
"""Layernorm_rope_rotate operation module."""
from dataclasses import dataclass
+from enum import Enum
import torch
import torch.nn.functional as F
@@ -9,7 +10,6 @@
from tilert.models.deepseek_v3_2.model_args import ModelArgs
from tilert.models.deepseek_v3_2.ops.rotate import rotate_activation
from tilert.models.utils import apply_rotary_emb
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -28,6 +28,8 @@ def layernorm_rope_rotate(
bias: torch.Tensor,
freqs_cis: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> None:
"""
Layernorm_rope_rotate operation.
@@ -69,7 +71,15 @@ def layernorm_rope_rotate(
raise ValueError("batch must be 1 in this version")
torch.ops.tilert.layernorm_rope_rotate_op(
- input_raw, cur_pos, k_cache_raw, weight, bias, freqs_cis, profile_logs
+ input_raw,
+ cur_pos,
+ k_cache_raw,
+ weight,
+ bias,
+ freqs_cis,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
)
@@ -103,9 +113,20 @@ def __call__(self) -> list[str]:
return self.tilert_tensor_alias
+class LayerNormRoPERotateAlgorithm(Enum):
+ """LayerNormRoPERotate algorithm."""
+
+ GENERAL = "general"
+
+
class LayerNormRoPERotate(TileRTModule):
"""LayerNormRoPERotate module: LayerNorm + RoPE + rotate on K indexer output."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [LayerNormRoPERotateAlgorithm.GENERAL],
+ "glm_5": [LayerNormRoPERotateAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -194,7 +215,7 @@ def golden_forward(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.
k_pe, k_nope = torch.split(
k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
- k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, interleaved=False).squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1)
return rotate_activation(k)
@@ -212,11 +233,8 @@ def tilert_forward(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.
self.tilert_bias,
rope_freqs,
self.profile_logs,
+ model_arch=self.model_args.arch_name,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return self.output
def __call__(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
diff --git a/tilert/models/deepseek_v3_2/ops/padded_allreduce_add.py b/tilert/models/deepseek_v3_2/ops/padded_allreduce_add.py
new file mode 100644
index 0000000..0ea3221
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/padded_allreduce_add.py
@@ -0,0 +1,147 @@
+"""PaddedAllReduceAdd operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "padded_allreduce_add",
+ "PaddedAllReduceAdd",
+]
+
+
+def padded_allreduce_add(
+ partial_buf: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ vec_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """Padded AllReduce + residual add for Device Group A (GPU 0).
+
+ GPU 0 contributes zeros to the 8-GPU AllReduce, then adds the residual.
+
+ Args:
+ partial_buf: Zero-filled partial buffer [1, L, hidden_dim] bf16.
+ x_in: Residual input [1, L, hidden_dim] bf16.
+ flag: AllReduce sync flag.
+ vec_out: Output tensor [1, L, hidden_dim] bf16.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.padded_allreduce_add_op(
+ partial_buf, x_in, flag, vec_out, profile_logs, model_arch, compute_kernel_type
+ )
+
+
+class PaddedAllReduceAddAlgorithm(Enum):
+ """PaddedAllReduceAdd algorithm."""
+
+ BF16 = "bf16"
+
+
+class PaddedAllReduceAdd(TileRTModule):
+ """PaddedAllReduceAdd module — zero-partial AllReduce + residual add."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [PaddedAllReduceAddAlgorithm.BF16],
+ "glm_5": [PaddedAllReduceAddAlgorithm.BF16],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.dim = self.model_args.dim
+
+ self.partial_buf: torch.Tensor | None = None
+
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_var_init = False
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate output buffer and persistent zero-filled partial buffer.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ )
+ self.partial_buf = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{self.device_id}")
+ self.is_var_init = True
+
+ def golden_forward(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ """Golden reference: allreduce(zeros) + x_in = x_in (single-GPU).
+
+ On a single GPU, allreduce of zeros returns zeros, so output = x_in.
+
+ Args:
+ x_in: Residual input [1, L, hidden_dim].
+
+ Returns:
+ Output tensor (copy of x_in).
+ """
+ return x_in.clone()
+
+ def tilert_forward(
+ self,
+ x_in: torch.Tensor,
+ flag: int,
+ ) -> torch.Tensor:
+ """Run TileRT kernel forward.
+
+ Args:
+ x_in: Residual input [1, L, hidden_dim].
+ flag: AllReduce sync flag.
+
+ Returns:
+ Output tensor [1, L, hidden_dim].
+ """
+ assert self.hidden_out is not None
+ assert self.partial_buf is not None
+ assert self.profile_logs is not None
+ padded_allreduce_add(
+ self.partial_buf,
+ x_in,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(x_in)
diff --git a/tilert/models/deepseek_v3_2/ops/projo_wkvb.py b/tilert/models/deepseek_v3_2/ops/projo_wkvb.py
new file mode 100644
index 0000000..845bd60
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/projo_wkvb.py
@@ -0,0 +1,483 @@
+"""ProjOWkvb operation module."""
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import init_func, weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "projo_wkvb",
+ "ProjoWKVb",
+ "ProjoWKVbAlgorithm",
+ "ProjoWKVbWeightsConverter",
+ "ProjoWKVbRefWeightsAlias",
+ "ProjoWKVbTilertWeightsAlias",
+]
+
+
+def projo_wkvb(
+ o_in: torch.Tensor,
+ wkv_b_b: torch.Tensor,
+ wkv_b_scales: torch.Tensor,
+ output: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "fp16mma",
+) -> None:
+ """
+ Define the ProjOWkvb operation.
+
+ Args:
+ o_in: Input tensor.
+ wkv_b_b: Weight tensor.
+ wkv_b_scales: Scale tensor.
+ output: Output tensor.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Kernel type ("fp16mma" for both DSv32 and GLM5).
+ """
+ torch.ops.tilert.projo_wkvb_op(
+ o_in,
+ wkv_b_b,
+ wkv_b_scales,
+ output,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=o_in.device),
+ )
+
+
+class ProjoWKVbAlgorithm(Enum):
+ """ProjoWKVb algorithm"""
+
+ GENERAL = "general"
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+class ProjoWKVbWeightsConverter(TilertWeightsConverter):
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args, num_devices)
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a [*, 16, 16] sub-block for the MMA kernel."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(mat_in: torch.Tensor, k_dim: int, pages: int) -> torch.Tensor:
+ """Swizzle [*, 16, K] matrix for paged MMA layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == k_dim
+ pre_shape = mat_in.shape[:-2]
+ k_per_page = k_dim // pages
+ n_k_tiles = k_per_page // 16
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = ProjoWKVbWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def convert_to_fp16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the FP16 MMA packed format."""
+ with torch.inference_mode():
+ wkv_b_b, wkv_b_b_scales = self.convert_to_general(weights)
+
+ n_heads = wkv_b_b.size(0)
+ v_head_dim = wkv_b_b.size(1)
+ kv_lora_rank = wkv_b_b.size(2)
+ num_ctas = 80
+ rows_per_cta = (n_heads * v_head_dim) // num_ctas
+
+ is_glm5 = self.model_args.arch_name == "glm_5"
+
+ w_flat = wkv_b_b.reshape(num_ctas, rows_per_cta // 16, 16, kv_lora_rank)
+ w_swizzled = ProjoWKVbWeightsConverter._swizzle_mma_16x16_for_pages(
+ w_flat, kv_lora_rank, pages=1
+ )
+ w_bytes = w_swizzled.reshape(num_ctas, -1)
+
+ scale_k_block = 128
+ n_scale_k = kv_lora_rank // scale_k_block
+ ctas_per_head = num_ctas // n_heads
+
+ if is_glm5:
+ ctas_per_scale_row = 64 // rows_per_cta
+ scales_per_cta = wkv_b_b_scales.repeat_interleave(ctas_per_scale_row, dim=1)
+ scales_per_cta = scales_per_cta.reshape(num_ctas, n_scale_k)
+ else:
+ scales_per_cta = wkv_b_b_scales.squeeze(1).repeat_interleave(ctas_per_head, dim=0)
+
+ scale_dtype = torch.float32
+ scales_per_cta = scales_per_cta.to(scale_dtype)
+
+ mat_bytes = rows_per_cta * kv_lora_rank
+ scale_bytes = n_scale_k * 4
+ page_size = (mat_bytes + scale_bytes + 127) // 128 * 128
+
+ scales_raw = scales_per_cta.contiguous().view(torch.float8_e4m3fn)
+ padding_size = page_size - mat_bytes - scales_raw.shape[-1]
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_b_b.device
+ )
+ return torch.cat([w_bytes, scales_raw, padding], dim=-1).contiguous()
+
+ def convert_to_bf16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the BF16 MMA packed format."""
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim
+ left_head_dim = wkvb_head_dim % self.model_args.block_size
+ hd_block = left_head_dim if left_head_dim != 0 else self.model_args.block_size
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local_heads = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local_heads % 2 != 0:
+ n_local_heads += 1
+
+ v_head_dim = self.model_args.v_head_dim
+ kv_lora_rank = self.model_args.kv_lora_rank
+ n_block = self.model_args.block_size
+
+ w = tilert_wkv_b_weights
+ s = tilert_wkv_b_scales
+ if self.model_args.n_heads % self.num_devices != 0:
+ n_current = w.size(0)
+ if n_current < n_local_heads:
+ pad_w = torch.zeros(
+ n_local_heads - n_current, *w.shape[1:], dtype=w.dtype, device=w.device
+ )
+ w = torch.cat([w, pad_w], dim=0)
+ pad_s = torch.zeros(
+ n_local_heads - n_current, *s.shape[1:], dtype=s.dtype, device=s.device
+ )
+ s = torch.cat([s, pad_s], dim=0)
+
+ s = s.float()
+ s = s.repeat_interleave(hd_block, dim=1).repeat_interleave(n_block, dim=2)
+ wkv_bf16 = (w.float() * s).to(torch.bfloat16)
+ n_heads = n_local_heads
+
+ num_ctas = 80
+ rows_per_cta = (n_heads * v_head_dim) // num_ctas
+
+ w_flat = wkv_bf16.reshape(num_ctas, rows_per_cta // 16, 16, kv_lora_rank)
+ w_swizzled = ProjoWKVbWeightsConverter._swizzle_mma_16x16_for_pages(
+ w_flat, kv_lora_rank, pages=1
+ )
+ w_bytes = w_swizzled.reshape(num_ctas, -1).contiguous().view(torch.float8_e4m3fn)
+
+ mat_bytes = rows_per_cta * kv_lora_rank * 2
+ page_size = (mat_bytes + 127) // 128 * 128
+ padding_size = page_size - w_bytes.shape[-1]
+
+ if padding_size > 0:
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_bf16.device
+ )
+ return torch.cat([w_bytes, padding], dim=-1).contiguous()
+ return w_bytes.contiguous()
+
+ def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ wkv_b_b = tilert_wkv_b_weights.contiguous()
+ wkv_b_b_scales = tilert_wkv_b_scales.contiguous()
+ if self.model_args.arch_name == "glm_5":
+ if wkv_b_b_scales.dtype != torch.float32:
+ print(
+ "Warning: ProjoWKVbWeightsConverter: "
+ + f"wkv_b_b_scales.dtype: {wkv_b_b_scales.dtype} "
+ + "is not float32, convert to float32."
+ )
+ wkv_b_b_scales = wkv_b_b_scales.to(torch.float32)
+ else:
+ wkv_b_b_scales = wkv_b_b_scales.to(torch.bfloat16)
+
+ wkv_b_b = wkv_b_b.detach()
+ wkv_b_b_scales = wkv_b_b_scales.detach()
+
+ if self.model_args.n_heads % self.num_devices != 0:
+ n_target = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_target % 2 != 0:
+ n_target += 1
+ n_current = wkv_b_b.size(0)
+ if n_current < n_target:
+ pad_b = torch.zeros(
+ n_target - n_current,
+ *wkv_b_b.shape[1:],
+ dtype=wkv_b_b.dtype,
+ device=wkv_b_b.device,
+ )
+ wkv_b_b = torch.cat([wkv_b_b, pad_b], dim=0)
+ pad_s = torch.zeros(
+ n_target - n_current,
+ *wkv_b_b_scales.shape[1:],
+ dtype=wkv_b_b_scales.dtype,
+ device=wkv_b_b_scales.device,
+ )
+ wkv_b_b_scales = torch.cat([wkv_b_b_scales, pad_s], dim=0)
+ wkv_b_b = wkv_b_b.contiguous()
+ wkv_b_b_scales = wkv_b_b_scales.contiguous()
+
+ return wkv_b_b, wkv_b_b_scales
+
+
+@dataclass
+class ProjoWKVbRefWeightsAlias:
+ """Reference weights alias for ProjoWKVb."""
+
+ wkv_b_weights = "self_attn.kv_b_proj.weight"
+ wkv_b_scales = "self_attn.kv_b_proj.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.wkv_b_weights, self.wkv_b_scales]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class ProjoWKVbTilertWeightsAlias:
+ """TileRT weights alias for ProjoWKVb."""
+
+ wkv_b_weights = "wkv_b2_weights"
+ wkv_b_scales = "wkv_b2_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.wkv_b_weights, self.wkv_b_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ProjoWKVb(TileRTModule):
+ """ProjoWKVb module: O projection (wkv_b) for output."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjoWKVbAlgorithm.FP16MMA],
+ "glm_5": [ProjoWKVbAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: ProjoWKVbRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = ProjoWKVbTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else ProjoWKVbRefWeightsAlias()
+ )
+
+ self.ref_wkv_b: torch.Tensor | None = None
+ self.tilert_wkv_b_b: torch.Tensor | None = None
+ self.tilert_wkv_b_b_scales: torch.Tensor | None = None
+ self.output: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ self.num_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local % 2 != 0:
+ n_local += 1
+ self.num_local_heads = n_local
+
+ self.wkvb_lora_rank = self.model_args.kv_lora_rank
+ self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size
+
+ self.wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim
+ self.wkvb_v_head_dim = self.model_args.v_head_dim
+ left_head_dim = self.wkvb_head_dim % self.model_args.block_size
+ if left_head_dim != 0:
+ assert self.model_args.block_size % left_head_dim == 0
+ self.head_dim_block_size = left_head_dim
+ self.head_dim_scale_repeat = self.model_args.block_size // self.head_dim_block_size
+ else:
+ self.head_dim_scale_repeat = 1
+ self.head_dim_block_size = self.model_args.block_size
+ self.wkvb_head_qsize = self.wkvb_head_dim // self.head_dim_block_size
+ self.wkvb_v_head_qsize = self.wkvb_v_head_dim // self.head_dim_block_size
+
+ self.compute_kernel_type = "fp16mma"
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_wkv_b_b, self.tilert_wkv_b_b_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: split weights and scales per device.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor.
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights]
+ kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales]
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ dev_weights = kv_b_proj_weight.view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = kv_b_proj_weight_scale.view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+ else:
+ from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqbWeightsConverter,
+ )
+
+ wq_b_list, scale_list = RmsnormProjqWqbWeightsConverter._redistribute_heads(
+ kv_b_proj_weight,
+ kv_b_proj_weight_scale,
+ n_total_heads=self.model_args.n_heads,
+ n_local_heads=self.num_local_heads,
+ num_devices=self.num_devices,
+ qk_head_dim=self.wkvb_head_dim,
+ block_size=self.model_args.block_size,
+ )
+ dev_weights = torch.stack(wq_b_list, dim=0).view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = torch.stack(scale_list, dim=0).view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+
+ wkvb = dev_weights[:, :, -self.wkvb_v_head_dim :]
+ wkvb_scales = (
+ dev_scales.contiguous()
+ .repeat(1, 1, self.head_dim_scale_repeat, 1)
+ .view(
+ self.num_devices,
+ self.num_local_heads,
+ self.wkvb_head_qsize,
+ self.wkvb_lora_rank_qsize,
+ )
+ .contiguous()[:, :, -self.wkvb_v_head_qsize :]
+ )
+ return {
+ self.tilert_weights_alias.wkv_b_weights: wkvb.contiguous(),
+ self.tilert_weights_alias.wkv_b_scales: wkvb_scales.contiguous(),
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ sharding_size = self.num_local_heads * self.wkvb_head_dim
+ sharding_start = self.device_id * sharding_size
+ sharding_end = sharding_start + sharding_size
+ wkv_b = weight_dequant(
+ state_dict[self.ref_weights_alias.wkv_b_weights],
+ state_dict[self.ref_weights_alias.wkv_b_scales],
+ )
+ wkv_b = wkv_b[sharding_start:sharding_end, :]
+ wkv_b = wkv_b.view(self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank)
+ self.ref_wkv_b = wkv_b[:, -self.wkvb_v_head_dim :]
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.init_tilert_weights_hmma(state_dict)
+
+ def init_tilert_weights_hmma(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with HMMA-packed weights."""
+ packed = ProjoWKVbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ ProjoWKVbAlgorithm.FP16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_b = packed
+ self.tilert_wkv_b_b_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "fp16mma"
+
+ def init_tilert_weights_hmma_bf16(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with BF16 HMMA-packed weights (dequantized, no scales)."""
+ packed = ProjoWKVbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ ProjoWKVbAlgorithm.BF16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_b = packed
+ self.tilert_wkv_b_b_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "bf16mma"
+
+ def init_random_weights(self) -> None:
+ padded_total_heads = self.num_local_heads * self.num_devices
+ wkv_b = init_func(
+ torch.empty(
+ padded_total_heads * self.wkvb_head_dim,
+ self.wkvb_lora_rank,
+ dtype=torch.float8_e4m3fn,
+ )
+ )
+ wkv_b_scales = init_func(
+ torch.empty(
+ padded_total_heads * self.wkvb_head_dim // self.model_args.block_size,
+ self.wkvb_lora_rank_qsize,
+ dtype=torch.float32,
+ )
+ )
+ ref_state_dict = dict(
+ zip(
+ self.ref_weights_alias(),
+ [wkv_b, wkv_b_scales],
+ )
+ )
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()})
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.output = torch.zeros(
+ (batch_size, seq_len, self.num_local_heads, self.wkvb_v_head_dim),
+ dtype=torch.bfloat16,
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, x_out: torch.Tensor) -> torch.Tensor:
+ assert self.ref_wkv_b is not None
+ return torch.einsum("bshc,hdc->bshd", x_out, self.ref_wkv_b)
+
+ def tilert_forward(self, x_out: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_wkv_b_b is not None
+ assert self.tilert_wkv_b_b_scales is not None
+ assert self.output is not None
+ assert self.profile_logs is not None
+ projo_wkvb(
+ x_out,
+ self.tilert_wkv_b_b,
+ self.tilert_wkv_b_b_scales,
+ self.output,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ compute_kernel_type=self.compute_kernel_type,
+ )
+ return self.output
diff --git a/python/models/deepseek_v3_2/ops/projq_wqb.py b/tilert/models/deepseek_v3_2/ops/projq_wqb.py
similarity index 51%
rename from python/models/deepseek_v3_2/ops/projq_wqb.py
rename to tilert/models/deepseek_v3_2/ops/projq_wqb.py
index 7287aa2..bc2bc12 100644
--- a/python/models/deepseek_v3_2/ops/projq_wqb.py
+++ b/tilert/models/deepseek_v3_2/ops/projq_wqb.py
@@ -1,5 +1,6 @@
"""ProjQB operation module."""
+import math
from dataclasses import dataclass
from enum import Enum
@@ -8,7 +9,6 @@
from tilert.models.base import TileRTModule, TilertWeightsConverter
from tilert.models.common import init_func, weight_dequant
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -27,6 +27,9 @@ def projq_wqb(
wkv_b_a_scales: torch.Tensor,
output: torch.Tensor,
profile_logs: torch.Tensor,
+ compute_kernel_type: str = "fp16mma",
+ *,
+ model_arch: str,
) -> None:
"""
Define the ProjqWqb operation.
@@ -37,17 +40,26 @@ def projq_wqb(
wkv_b_a_scales: Scale tensor.
output: Output tensor.
profile_logs: Profile logs tensor.
+ compute_kernel_type: Kernel type ("fp16mma").
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
"""
- if q_nope_in.shape[-1] == 128:
- torch.ops.tilert.projq_wqb_op(q_nope_in, wkv_b_a, wkv_b_a_scales, output, profile_logs)
- elif q_nope_in.shape[-1] == 192:
- torch.ops.tilert.proj_qb_glm5_op(q_nope_in, wkv_b_a, wkv_b_a_scales, output, profile_logs)
+ torch.ops.tilert.projq_wqb_op(
+ q_nope_in,
+ wkv_b_a,
+ wkv_b_a_scales,
+ output,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
class ProjqWqbAlgorithm(Enum):
"""ProjqWqb algorithm"""
GENERAL = "general"
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
class ProjqWqbWeightsConverter(TilertWeightsConverter):
@@ -56,11 +68,119 @@ def __init__(self, model_args: ModelArgs, num_devices: int, head_dim_block_size:
self.head_dim_block_size = head_dim_block_size
self.impl_block_size = 64
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a [*, 16, 16] sub-block for the MMA kernel."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(mat_in: torch.Tensor, k_dim: int, pages: int) -> torch.Tensor:
+ """Swizzle [*, 16, K] matrix for paged MMA layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == k_dim
+ pre_shape = mat_in.shape[:-2]
+ k_per_page = k_dim // pages
+ n_k_tiles = k_per_page // 16
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = ProjqWqbWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def convert_to_fp16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the FP16 MMA packed format."""
+ with torch.inference_mode():
+ wkv_b_a, wkv_b_a_scales = self.convert_to_general(weights)
+
+ n_heads = wkv_b_a.size(0)
+ head_dim = wkv_b_a.size(2)
+ kv_lora_rank = wkv_b_a.size(1)
+ num_ctas = 80
+ rows_per_cta = (n_heads * kv_lora_rank) // num_ctas
+
+ is_glm5 = self.model_args.arch_name == "glm_5"
+
+ w_flat = wkv_b_a.reshape(num_ctas, rows_per_cta // 16, 16, head_dim)
+ w_swizzled = self._swizzle_mma_16x16_for_pages(w_flat, head_dim, pages=1)
+ w_bytes = w_swizzled.reshape(num_ctas, -1)
+
+ kScalesPerPage = head_dim // 64
+
+ if is_glm5:
+ ctas_per_scale_row = 128 // rows_per_cta
+ scales_expanded = wkv_b_a_scales.repeat_interleave(ctas_per_scale_row, dim=1)
+ scales_per_cta = scales_expanded.reshape(num_ctas, kScalesPerPage)
+ scale_dtype = torch.float32
+ else:
+ scales_per_cta = wkv_b_a_scales.reshape(num_ctas, kScalesPerPage)
+ scale_dtype = torch.bfloat16
+
+ mat_bytes = rows_per_cta * head_dim
+ scale_elem_bytes = 4 if scale_dtype == torch.float32 else 2
+ scale_bytes = kScalesPerPage * scale_elem_bytes
+ page_size = (mat_bytes + scale_bytes + 127) // 128 * 128
+
+ scales_raw = scales_per_cta.to(scale_dtype).contiguous().view(torch.float8_e4m3fn)
+ padding_size = page_size - mat_bytes - scales_raw.shape[-1]
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_b_a.device
+ )
+ return torch.cat([w_bytes, scales_raw, padding], dim=-1).contiguous()
+
+ def convert_to_bf16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the BF16 MMA packed format."""
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local_heads = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local_heads % 2 != 0:
+ n_local_heads += 1
+
+ nope_head_dim = self.model_args.qk_nope_head_dim
+ kv_lora_rank = self.model_args.kv_lora_rank
+ hd_block = self.head_dim_block_size
+ n_block = self.model_args.block_size
+
+ s = tilert_wkv_b_scales.float()
+ s = s.repeat_interleave(hd_block, dim=1).repeat_interleave(n_block, dim=2)
+ wkv_bf16 = (
+ (tilert_wkv_b_weights.float() * s).transpose(1, 2).contiguous().to(torch.bfloat16)
+ )
+ n_heads = n_local_heads
+ head_dim = nope_head_dim
+
+ num_ctas = 80
+ rows_per_cta = (n_heads * kv_lora_rank) // num_ctas
+
+ w_flat = wkv_bf16.reshape(num_ctas, rows_per_cta // 16, 16, head_dim)
+ w_swizzled = self._swizzle_mma_16x16_for_pages(w_flat, head_dim, pages=1)
+ w_bytes = w_swizzled.reshape(num_ctas, -1).contiguous().view(torch.float8_e4m3fn)
+
+ mat_bytes = rows_per_cta * head_dim * 2
+ page_size = (mat_bytes + 127) // 128 * 128
+ padding_size = page_size - w_bytes.shape[-1]
+
+ if padding_size > 0:
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_bf16.device
+ )
+ return torch.cat([w_bytes, padding], dim=-1).contiguous()
+ return w_bytes.contiguous()
+
def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
with torch.inference_mode():
tilert_wkv_b_weights, tilert_wkv_b_scales = weights
- n_local_heads = self.model_args.n_heads // self.num_devices
+ if self.model_args.n_heads % self.num_devices == 0:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local_heads = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local_heads % 2 != 0:
+ n_local_heads += 1
wkv_b = tilert_wkv_b_weights
wkv_b_scales_raw = tilert_wkv_b_scales
@@ -84,9 +204,8 @@ def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor,
+ "is not float32, convert to float32."
)
wkv_b_a_scales = wkv_b_a_scales.to(torch.float32)
- else: # DS v3.2, use bfloat16 for wkv_b_a_scales
+ else:
wkv_b_a_scales = wkv_b_a_scales.to(torch.bfloat16)
- # Tiling to fit tilert input
if self.head_dim_block_size != self.impl_block_size:
repeats = self.head_dim_block_size // self.impl_block_size
wkv_b_a_scales = wkv_b_a_scales.repeat(1, 1, repeats).contiguous()
@@ -130,6 +249,11 @@ def __call__(self) -> list[str]:
class ProjqWqb(TileRTModule):
"""ProjqWqb module: Q projection (wkv_b) for KV LoRA."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjqWqbAlgorithm.FP16MMA],
+ "glm_5": [ProjqWqbAlgorithm.FP16MMA],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -155,9 +279,16 @@ def __init__(
self.output: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
- self.num_local_heads = self.model_args.n_heads // self.num_devices
+ self.compute_kernel_type = "fp16mma"
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ self.num_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local % 2 != 0:
+ n_local += 1
+ self.num_local_heads = n_local
- # lora dim and quant block size
self.wkvb_lora_rank = self.model_args.kv_lora_rank
self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size
@@ -194,18 +325,39 @@ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, tor
kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights]
kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales]
- dev_heads = (self.num_devices, self.num_local_heads)
- wkvb = kv_b_proj_weight.view(*dev_heads, self.wkvb_head_dim, self.wkvb_lora_rank)[
- :, :, : self.wkvb_nope_head_dim
- ]
- wkvb_scales = (
- kv_b_proj_weight_scale.view(
- self.num_devices,
- self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size,
- 1,
- self.wkvb_lora_rank_qsize,
+ if self.model_args.n_heads % self.num_devices == 0:
+ dev_weights = kv_b_proj_weight.view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = kv_b_proj_weight_scale.view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
)
- .contiguous()
+ else:
+ from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqbWeightsConverter,
+ )
+
+ wq_b_list, scale_list = RmsnormProjqWqbWeightsConverter._redistribute_heads(
+ kv_b_proj_weight,
+ kv_b_proj_weight_scale,
+ n_total_heads=self.model_args.n_heads,
+ n_local_heads=self.num_local_heads,
+ num_devices=self.num_devices,
+ qk_head_dim=self.wkvb_head_dim,
+ block_size=self.model_args.block_size,
+ )
+ dev_weights = torch.stack(wq_b_list, dim=0).view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = torch.stack(scale_list, dim=0).view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+
+ wkvb = dev_weights[:, :, : self.wkvb_nope_head_dim]
+ wkvb_scales = (
+ dev_scales.contiguous()
.repeat(1, 1, self.head_dim_scale_repeat, 1)
.view(
self.num_devices,
@@ -233,29 +385,50 @@ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
self.ref_wkv_b = wkv_b[:, : self.wkvb_nope_head_dim]
def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- self.tilert_wkv_b_a, self.tilert_wkv_b_a_scales = ProjqWqbWeightsConverter(
+ self.init_tilert_weights_hmma(state_dict)
+
+ def init_tilert_weights_hmma(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with HMMA-packed weights."""
+ packed = ProjqWqbWeightsConverter(
self.model_args, self.num_devices, self.head_dim_block_size
).dispatch(
- ProjqWqbAlgorithm.GENERAL,
+ ProjqWqbAlgorithm.FP16MMA,
[
state_dict[self.tilert_weights_alias.wkv_b_weights],
state_dict[self.tilert_weights_alias.wkv_b_scales],
],
)
+ self.tilert_wkv_b_a = packed
+ self.tilert_wkv_b_a_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "fp16mma"
+
+ def init_tilert_weights_hmma_bf16(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with BF16 HMMA-packed weights (dequantized, no scales)."""
+ packed = ProjqWqbWeightsConverter(
+ self.model_args, self.num_devices, self.head_dim_block_size
+ ).dispatch(
+ ProjqWqbAlgorithm.BF16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_a = packed
+ self.tilert_wkv_b_a_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "bf16mma"
def init_random_weights(self) -> None:
+ padded_total_heads = self.num_local_heads * self.num_devices
wkv_b = init_func(
torch.empty(
- self.model_args.n_heads * self.wkvb_head_dim,
+ padded_total_heads * self.wkvb_head_dim,
self.wkvb_lora_rank,
dtype=torch.float8_e4m3fn,
)
)
wkv_b_scales = init_func(
torch.empty(
- # Block quant should be applied to the original weight dimension (including head
- # dimension)
- self.model_args.n_heads * self.wkvb_head_dim // self.model_args.block_size,
+ padded_total_heads * self.wkvb_head_dim // self.model_args.block_size,
self.wkvb_lora_rank_qsize,
dtype=torch.float32,
)
@@ -287,9 +460,7 @@ def tilert_forward(self, q_nope: torch.Tensor) -> torch.Tensor:
self.tilert_wkv_b_a_scales,
self.output,
self.profile_logs,
+ self.compute_kernel_type,
+ model_arch=self.model_args.arch_name,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return self.output
diff --git a/python/models/deepseek_v3_2/ops/projx_wis.py b/tilert/models/deepseek_v3_2/ops/projx_wis.py
similarity index 62%
rename from python/models/deepseek_v3_2/ops/projx_wis.py
rename to tilert/models/deepseek_v3_2/ops/projx_wis.py
index e264659..ebd7ff3 100644
--- a/python/models/deepseek_v3_2/ops/projx_wis.py
+++ b/tilert/models/deepseek_v3_2/ops/projx_wis.py
@@ -1,13 +1,13 @@
"""ProjxWis operation module."""
from dataclasses import dataclass
+from enum import Enum
import torch
from tilert.models.base import TileRTModule
from tilert.models.common import init_func
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -22,7 +22,9 @@ def projx_wis(
x_in: torch.Tensor,
w: torch.Tensor,
output: torch.Tensor,
+ compute_kernel_type: str,
profile_logs: torch.Tensor,
+ model_arch: str,
) -> None:
"""
Define the ProjxWis operation.
@@ -31,12 +33,11 @@ def projx_wis(
x_in: Input tensor.
w: Weight tensor.
output: Output tensor.
+ compute_kernel_type: Compute kernel type ("bf16" or "bf16mma").
profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
"""
- if x_in.shape[-1] == 7168:
- torch.ops.tilert.proj_w_op(x_in, w, output, profile_logs)
- elif x_in.shape[-1] == 6144:
- torch.ops.tilert.proj_w_glm5_op(x_in, w, output, profile_logs)
+ torch.ops.tilert.proj_w_op(x_in, w, output, model_arch, compute_kernel_type, profile_logs)
@dataclass
@@ -67,15 +68,33 @@ def __call__(self) -> list[str]:
return self.tilert_tensor_alias
+class ProjxWisAlgorithm(Enum):
+ """ProjxWis algorithm."""
+
+ BF16 = "bf16"
+ BF16MMA = "bf16mma"
+
+
class ProjxWis(TileRTModule):
"""ProjxWis module: linear projection for indexer score weights."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjxWisAlgorithm.BF16, ProjxWisAlgorithm.BF16MMA],
+ "glm_5": [ProjxWisAlgorithm.BF16, ProjxWisAlgorithm.BF16MMA],
+ }
+
+ _HMMA_CONFIGS = {
+ 7168: (4, 16, 7),
+ 6144: (2, 16, 6),
+ }
+
def __init__(
self,
model_args: ModelArgs,
num_devices: int,
device_id: int = 0,
ref_weights_alias: ProjxWisRefWeightsAlias | None = None,
+ compute_kernel_type: str | None = None,
):
super().__init__(
self.__class__.__name__,
@@ -89,7 +108,6 @@ def __init__(
ref_weights_alias if ref_weights_alias is not None else ProjxWisRefWeightsAlias()
)
- # Backward compatibility: expose list for load_weights_for_layer etc.
self.ref_tensor_alias = self.ref_weights_alias.ref_tensor_alias
self.ref_w: torch.Tensor | None = None
@@ -100,6 +118,33 @@ def __init__(
self.dim = model_args.dim
self.index_n_heads = model_args.index_n_heads
+ if compute_kernel_type is not None:
+ self.compute_kernel_type = compute_kernel_type
+ else:
+ self.compute_kernel_type = "bf16"
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a 16x16 BF16 tile for the MMA kernel."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _to_hmma_layout(
+ w_orig: torch.Tensor, n_ctas: int, rows_per_cta: int, x_dim: int, num_pages: int
+ ) -> torch.Tensor:
+ """Convert [output_dim, x_dim] BF16 weights to the MMA layout."""
+ cols_per_page = x_dim // num_pages
+ n_k_tiles = cols_per_page // 16
+ w = w_orig.reshape(n_ctas, rows_per_cta, num_pages, cols_per_page)
+ w = w.transpose(1, 2)
+ n_row_tiles = rows_per_cta // 16
+ w = w.reshape(n_ctas, num_pages, n_row_tiles, 16, n_k_tiles, 16).transpose(-3, -2)
+ w = ProjxWis._swizzle_mma_16x16(w)
+ return w.reshape(n_ctas, -1).contiguous()
+
@property
def tilert_tensor_alias(self) -> list[str]:
return self.tilert_weights_alias.tilert_tensor_alias
@@ -117,8 +162,14 @@ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, tor
Returns:
Map from tilert weight alias to (num_devices, ...) tensors.
"""
- w = weights_map[self.ref_weights_alias.w_weights][None, ...].repeat(self.num_devices, 1, 1)
- return {self.tilert_weights_alias.w_weights: w}
+ w = weights_map[self.ref_weights_alias.w_weights]
+ if self.compute_kernel_type == "bf16mma":
+ n_ctas, rows_per_cta, num_pages = self._HMMA_CONFIGS[self.dim]
+ w_hmma = self._to_hmma_layout(w, n_ctas, rows_per_cta, self.dim, num_pages)
+ w_out = w_hmma[None, ...].repeat(self.num_devices, 1, 1)
+ else:
+ w_out = w[None, ...].repeat(self.num_devices, 1, 1)
+ return {self.tilert_weights_alias.w_weights: w_out}
def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
w = state_dict[self.ref_weights_alias.w_weights]
@@ -149,9 +200,12 @@ def tilert_forward(self, x_norm: torch.Tensor) -> torch.Tensor:
assert self.tilert_w is not None
assert self.output is not None
assert self.profile_logs is not None
- projx_wis(x_norm, self.tilert_w, self.output, self.profile_logs)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
+ projx_wis(
+ x_norm,
+ self.tilert_w,
+ self.output,
+ self.compute_kernel_type,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
return self.output
diff --git a/tilert/models/deepseek_v3_2/ops/projx_wqaki.py b/tilert/models/deepseek_v3_2/ops/projx_wqaki.py
new file mode 100644
index 0000000..9bc90b5
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/projx_wqaki.py
@@ -0,0 +1,247 @@
+"""ProjxWqaki operation module."""
+
+import torch
+
+__all__ = [
+ "projx_wqaki",
+ "ProjxWqakiWeightsConverter",
+]
+
+
+def projx_wqaki(
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ wqaki: torch.Tensor,
+ out_q: torch.Tensor,
+ out_ki: torch.Tensor,
+ profile_logs: torch.Tensor,
+ compute_kernel_type: str = "fp8mma",
+ *,
+ model_arch: str,
+) -> None:
+ """FP8 MMA projection for q, ki.
+
+ Args:
+ x_quant: FP8 quantized hidden states [1, seq_len, hidden_dim].
+ x_scale: Scale factors for x_quant.
+ wqaki: Packed FP8 weights + scales for q, ki.
+ out_q: Output q tensor.
+ out_ki: Output ki tensor.
+ profile_logs: Profile logs tensor.
+ compute_kernel_type: Kernel type ("fp8mma", "fp8mma_68cta", "fp8mma_136cta").
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ """
+ torch.ops.tilert.projx_wqaki_op(
+ x_quant,
+ x_scale,
+ wqaki,
+ out_q,
+ out_ki,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=x_quant.device),
+ )
+
+
+class ProjxWqakiWeightsConverter:
+ """Weight converter for ProjxWqaki kernel."""
+
+ @staticmethod
+ def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ assert mat_in.dtype == torch.float8_e4m3fn
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def convert_dsv32(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wki: torch.Tensor,
+ wki_scale: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert DSV3.2 weights to the FP8 MMA page layout."""
+ with torch.inference_mode():
+ wq_a_scale = wq_a_scale.to(torch.bfloat16)
+ wki_scale = wki_scale.to(torch.bfloat16)
+
+ dim = 7168
+ q_rows = 1536
+ ki_rows = 128
+ total_rows = q_rows + ki_rows
+ n_blocks = total_rows // 16
+ scale_dim = dim // 128
+
+ n_q_blocks = q_rows // 16
+ n_ki_blocks = ki_rows // 16
+ wq_a = wq_a.reshape(n_q_blocks, 16, dim)
+ wq_a_scale = (
+ wq_a_scale.reshape(wq_a_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_q_blocks // wq_a_scale.shape[0], 1)
+ .reshape(n_q_blocks, scale_dim)
+ )
+ wki = wki.reshape(n_ki_blocks, 16, dim)
+ wki_scale = (
+ wki_scale.reshape(wki_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_ki_blocks // wki_scale.shape[0], 1)
+ .reshape(n_ki_blocks, scale_dim)
+ )
+
+ wqaki = torch.cat([wq_a, wki], dim=0)
+ wqaki_scale = torch.cat([wq_a_scale, wki_scale], dim=0)
+
+ swizzle = ProjxWqakiWeightsConverter._swizzle_qmma_16x32
+
+ wqaki_0 = wqaki[..., :2048]
+ wqaki_0_scale = wqaki_scale[..., :16].contiguous().view(torch.float8_e4m3fn)
+ wqaki_1 = wqaki[..., 2048:4096]
+ wqaki_1_scale = wqaki_scale[..., 16:32].contiguous().view(torch.float8_e4m3fn)
+ wqaki_2 = wqaki[..., 4096:6144]
+ wqaki_2_scale = wqaki_scale[..., 32:48].contiguous().view(torch.float8_e4m3fn)
+ wqaki_3 = wqaki[..., 6144:7168]
+ wqaki_3_scale = wqaki_scale[..., 48:56].contiguous().view(torch.float8_e4m3fn)
+
+ wqaki_0 = wqaki_0.reshape(n_blocks, 16, 64, 32).transpose(1, 2)
+ wqaki_0 = swizzle(wqaki_0).reshape(n_blocks, 16 * 2048)
+
+ wqaki_1 = wqaki_1.reshape(n_blocks, 16, 64, 32).transpose(1, 2)
+ wqaki_1 = swizzle(wqaki_1).reshape(n_blocks, 16 * 2048)
+
+ wqaki_2 = wqaki_2.reshape(n_blocks, 16, 64, 32).transpose(1, 2)
+ wqaki_2 = swizzle(wqaki_2).reshape(n_blocks, 16 * 2048)
+
+ wqaki_3 = wqaki_3.reshape(n_blocks, 16, 32, 32).transpose(1, 2)
+ wqaki_3 = swizzle(wqaki_3).reshape(n_blocks, 16 * 1024)
+
+ padding_scale0 = torch.zeros(
+ (n_blocks, 48), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+ padding_scale1 = torch.zeros(
+ (n_blocks, 48), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+ padding_scale2 = torch.zeros(
+ (n_blocks, 48), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+ padding_scale3 = torch.zeros(
+ (n_blocks, 56), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+
+ return torch.cat(
+ [
+ wqaki_0,
+ wqaki_0_scale,
+ padding_scale0,
+ wqaki_1,
+ wqaki_1_scale,
+ padding_scale1,
+ wqaki_2,
+ wqaki_2_scale,
+ padding_scale2,
+ wqaki_3,
+ wqaki_3_scale,
+ padding_scale3,
+ ],
+ dim=1,
+ ).contiguous()
+
+ @staticmethod
+ def convert_glm5_68cta(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wki: torch.Tensor,
+ wki_scale: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert GLM5 weights to the FP8 MMA page layout (68CTA)."""
+ with torch.inference_mode():
+ wq_a_scale = wq_a_scale.to(torch.float32)
+ wki_scale = wki_scale.to(torch.float32)
+
+ dim = 6144
+ q_rows = 2048
+ ki_rows = 128
+ total_rows = q_rows + ki_rows
+ n_blocks = total_rows // 32
+ scale_dim = dim // 128
+
+ n_q_blocks = q_rows // 32
+ n_ki_blocks = ki_rows // 32
+
+ wqaki_raw = torch.cat([wq_a, wki], dim=0).reshape(n_blocks, 32, dim)
+
+ wq_a_scale = (
+ wq_a_scale.reshape(wq_a_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_q_blocks // wq_a_scale.shape[0], 1)
+ .reshape(n_q_blocks, scale_dim)
+ )
+ wki_scale = (
+ wki_scale.reshape(wki_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_ki_blocks // wki_scale.shape[0], 1)
+ .reshape(n_ki_blocks, scale_dim)
+ )
+ wqaki_scales = torch.cat([wq_a_scale, wki_scale], dim=0)
+
+ swizzle = ProjxWqakiWeightsConverter._swizzle_qmma_16x32
+
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 32, 6, 1024).transpose(1, 2)
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 6, 2, 16, 32, 32).transpose(3, 4)
+ wqaki_raw = swizzle(wqaki_raw).reshape(n_blocks, 6, 32 * 1024)
+ wqaki_scales = wqaki_scales.reshape(n_blocks, 6, 8).view(torch.float8_e4m3fn)
+ wqaki_padding = torch.zeros(
+ (n_blocks, 6, 128 - wqaki_scales.shape[-1]),
+ dtype=torch.float8_e4m3fn,
+ device=wq_a.device,
+ )
+ return torch.cat([wqaki_raw, wqaki_scales, wqaki_padding], dim=-1).contiguous()
+
+ @staticmethod
+ def convert_glm5_136cta(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wki: torch.Tensor,
+ wki_scale: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert GLM5 weights to the FP8 MMA page layout (136CTA)."""
+ with torch.inference_mode():
+ wq_a_scale = wq_a_scale.to(torch.float32)
+ wki_scale = wki_scale.to(torch.float32)
+
+ dim = 6144
+ q_rows = 2048
+ ki_rows = 128
+ total_rows = q_rows + ki_rows
+ n_blocks = total_rows // 16
+ scale_dim = dim // 128
+
+ n_q_blocks = q_rows // 16
+ n_ki_blocks = ki_rows // 16
+
+ wq_a = wq_a.reshape(n_q_blocks, 16, dim)
+ wq_a_scale = (
+ wq_a_scale.reshape(wq_a_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_q_blocks // wq_a_scale.shape[0], 1)
+ .reshape(n_q_blocks, scale_dim)
+ )
+ wki = wki.reshape(n_ki_blocks, 16, dim)
+ wki_scale = (
+ wki_scale.reshape(wki_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_ki_blocks // wki_scale.shape[0], 1)
+ .reshape(n_ki_blocks, scale_dim)
+ )
+
+ wqaki_raw = torch.cat([wq_a, wki], dim=0)
+ wqaki_scales = torch.cat([wq_a_scale, wki_scale], dim=0)
+
+ swizzle = ProjxWqakiWeightsConverter._swizzle_qmma_16x32
+
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 16, 3, 2048).transpose(1, 2)
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 3, 1, 16, 64, 32).transpose(3, 4)
+ wqaki_raw = swizzle(wqaki_raw).reshape(n_blocks, 3, 16 * 2048)
+ wqaki_scales = wqaki_scales.reshape(n_blocks, 3, 16).view(torch.float8_e4m3fn)
+ wqaki_padding = torch.zeros(
+ (n_blocks, 3, 128 - wqaki_scales.shape[-1]),
+ dtype=torch.float8_e4m3fn,
+ device=wq_a.device,
+ )
+ return torch.cat([wqaki_raw, wqaki_scales, wqaki_padding], dim=-1).contiguous()
diff --git a/tilert/models/deepseek_v3_2/ops/projx_wqkva.py b/tilert/models/deepseek_v3_2/ops/projx_wqkva.py
new file mode 100644
index 0000000..0d36af8
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/projx_wqkva.py
@@ -0,0 +1,329 @@
+"""ProjXWqkva operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkva import (
+ RMSNormProjQKVAFP8MMAWeightsConverter,
+ RMSNormProjQKVAFP16MMAWeightsConverter,
+)
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "ProjXWqkva",
+ "projx_wqkva",
+]
+
+
+def projx_wqkva(
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ wqkva: torch.Tensor,
+ cur_pos: torch.Tensor,
+ q_out: torch.Tensor,
+ kv_out: torch.Tensor,
+ pe_cache_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ compute_kernel_type: str = "fp8mma",
+ *,
+ model_arch: str,
+) -> None:
+ """FP8 MMA projection for q, kv, pe_cache (DSV3.2)."""
+ torch.ops.tilert.projx_wqkva_op(
+ x_quant,
+ x_scale,
+ wqkva,
+ cur_pos,
+ q_out,
+ kv_out,
+ pe_cache_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=x_quant.device),
+ )
+
+
+class ProjXWqkvaRefWeightsAlias:
+ """Reference weight aliases for ProjXWqkva."""
+
+ x_rmsnorm_gamma = "input_layernorm.weight"
+ q_a_weights = "self_attn.q_a_proj.weight"
+ q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
+ kv_a_weights = "self_attn.kv_a_proj_with_mqa.weight"
+ kv_a_scales = "self_attn.kv_a_proj_with_mqa.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+class ProjXWqkvaTilertWeightsAlias:
+ """Tilert weight aliases for ProjXWqkva."""
+
+ q_a_weights = "q_a_weights"
+ q_a_scales = "q_a_scales"
+ kv_a_weights = "kv_a_weights"
+ kv_a_scales = "kv_a_scales"
+ w_pe_weights = "w_pe_weights"
+ w_pe_scales = "w_pe_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ self.w_pe_weights,
+ self.w_pe_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ProjXWqkvaAlgorithm(Enum):
+ """ProjXWqkva algorithm."""
+
+ FP8MMA = "fp8mma"
+ FP16MMA = "fp16mma"
+
+
+class ProjXWqkva(TileRTModule):
+ """FP8 MMA projection module for q, kv, pe_cache."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjXWqkvaAlgorithm.FP8MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: ProjXWqkvaRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = ProjXWqkvaTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else ProjXWqkvaRefWeightsAlias()
+ )
+
+ self.dim = self.model_args.dim
+ self.q_lora_rank = self.model_args.q_lora_rank
+ self.kv_lora_rank = self.model_args.kv_lora_rank
+ self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ self.block_size = self.model_args.block_size
+ self.eps = self.model_args.eps
+
+ self.ref_wq_a: torch.Tensor | None = None
+ self.ref_wkv_a: torch.Tensor | None = None
+ self.ref_w_pe: torch.Tensor | None = None
+
+ self.tilert_wqkva: torch.Tensor | None = None
+
+ self.q_out: torch.Tensor | None = None
+ self.kv_out: torch.Tensor | None = None
+ self.pe_cache_out: torch.Tensor | None = None
+ self.cur_pos: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.compute_kernel_type = "fp8mma"
+
+ def set_algorithm(self, algorithm: Enum) -> None:
+ super().set_algorithm(algorithm)
+ if algorithm == ProjXWqkvaAlgorithm.FP16MMA:
+ self.compute_kernel_type = "fp16mma"
+ else:
+ self.compute_kernel_type = "fp8mma"
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat weights for device sharding."""
+ q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ kv_a_mqa = weights_map[self.ref_weights_alias.kv_a_weights]
+ kv_a_proj_weight = kv_a_mqa[: self.kv_lora_rank, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight = kv_a_mqa[self.kv_lora_rank :, :][None, ...].repeat(self.num_devices, 1, 1)
+ kv_a_mqa_scale = weights_map[self.ref_weights_alias.kv_a_scales]
+ kv_scale_rows = (self.kv_lora_rank + self.block_size - 1) // self.block_size
+ kv_a_proj_weight_scale = kv_a_mqa_scale[:kv_scale_rows, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight_scale = kv_a_mqa_scale[kv_scale_rows:, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ return {
+ self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
+ self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
+ self.tilert_weights_alias.kv_a_weights: kv_a_proj_weight,
+ self.tilert_weights_alias.kv_a_scales: kv_a_proj_weight_scale,
+ self.tilert_weights_alias.w_pe_weights: w_pe_weight,
+ self.tilert_weights_alias.w_pe_scales: w_pe_weight_scale,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ aliases = self.ref_weights_alias()
+ self.ref_wq_a = weight_dequant(state_dict[aliases[1]], state_dict[aliases[2]])
+ kv_a_mqa = weight_dequant(state_dict[aliases[3]], state_dict[aliases[4]])
+ self.ref_wkv_a = kv_a_mqa[: self.kv_lora_rank, :]
+ self.ref_w_pe = kv_a_mqa[self.kv_lora_rank :, :]
+
+ assert self.ref_wq_a.shape == (self.q_lora_rank, self.dim)
+ assert self.ref_wkv_a.shape == (self.kv_lora_rank, self.dim)
+ assert self.ref_w_pe.shape == (self.qk_rope_head_dim, self.dim)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ tilert_aliases = self.tilert_weights_alias()
+ wq_a = state_dict[tilert_aliases[0]]
+ wq_a_scale = state_dict[tilert_aliases[1]]
+ wkv_a = state_dict[tilert_aliases[2]]
+ wkv_a_scale = state_dict[tilert_aliases[3]]
+ w_pe = state_dict[tilert_aliases[4]]
+ w_pe_scale = state_dict[tilert_aliases[5]]
+ dummy_gamma = torch.zeros(self.dim, dtype=torch.float32, device=wq_a.device)
+
+ if self.algorithm == ProjXWqkvaAlgorithm.FP16MMA:
+ self.tilert_wqkva, _ = RMSNormProjQKVAFP16MMAWeightsConverter.convert_to_fp16_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ dummy_gamma,
+ hidden_dim=self.dim,
+ q_lora_rank=self.q_lora_rank,
+ )
+ else:
+ self.tilert_wqkva, _ = RMSNormProjQKVAFP8MMAWeightsConverter.convert_to_fp8_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ dummy_gamma,
+ hidden_dim=self.dim,
+ q_lora_rank=self.q_lora_rank,
+ )
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, max_len: int = 128) -> None:
+ self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
+ self.kv_out = torch.zeros((batch_size, seq_len, self.kv_lora_rank), dtype=torch.bfloat16)
+ self.pe_cache_out = torch.zeros(
+ (batch_size, max_len, self.qk_rope_head_dim), dtype=torch.bfloat16
+ )
+ self.cur_pos = torch.zeros((1,), dtype=torch.int32)
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def init_random_weights(self) -> None:
+ bs = self.block_size
+ dim_scale_dim = self.dim // bs
+ q_scale_dim = (self.q_lora_rank + bs - 1) // bs
+ kv_mqa_rows = self.kv_lora_rank + self.qk_rope_head_dim
+ kv_mqa_scale_dim = (kv_mqa_rows + bs - 1) // bs
+ scale_dtype = torch.bfloat16
+
+ tensor_list = [
+ torch.randn(self.dim, dtype=torch.float32),
+ torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(kv_mqa_rows, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(kv_mqa_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def golden_forward(
+ self,
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ cur_pos: int = 0, # noqa: U100
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Pure PyTorch reference: dequant FP8 -> matmul -> q, kv, pe."""
+ assert self.ref_wq_a is not None
+ assert self.ref_wkv_a is not None
+ assert self.ref_w_pe is not None
+
+ if self.algorithm == ProjXWqkvaAlgorithm.FP16MMA:
+ x_float = x_quant.float()
+ else:
+ x_fp8 = x_quant.to(torch.float32)
+ scale_expanded = x_scale.unsqueeze(-1).repeat(1, 1, 1, self.block_size)
+ scale_expanded = scale_expanded.reshape(x_quant.shape)
+ x_float = x_fp8 * scale_expanded
+
+ q_out = torch.matmul(x_float, self.ref_wq_a.transpose(0, 1).float())
+ kv_out = torch.matmul(x_float, self.ref_wkv_a.transpose(0, 1).float())
+ pe_out = torch.matmul(x_float, self.ref_w_pe.transpose(0, 1).float())
+ return (
+ q_out.to(torch.bfloat16),
+ kv_out.to(torch.bfloat16),
+ pe_out.to(torch.bfloat16),
+ )
+
+ def tilert_forward(
+ self,
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run FP8 QMMA GEMV via TileRT CUDA kernel."""
+ assert self.cur_pos is not None
+ assert self.pe_cache_out is not None
+ self.cur_pos.fill_(cur_pos)
+ projx_wqkva(
+ x_quant,
+ x_scale,
+ self.tilert_wqkva,
+ self.cur_pos,
+ self.q_out,
+ self.kv_out,
+ self.pe_cache_out,
+ self.profile_logs,
+ self.compute_kernel_type,
+ model_arch=self.model_args.arch_name,
+ )
+
+ seq_len = x_quant.size(-2)
+ pe_at_pos = self.pe_cache_out[:, cur_pos : cur_pos + seq_len, :]
+ return self.q_out, self.kv_out, pe_at_pos
+
+ def __call__(
+ self,
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.golden_forward(x_quant, x_scale, cur_pos)
diff --git a/python/models/deepseek_v3_2/ops/qkv_rope.py b/tilert/models/deepseek_v3_2/ops/qkv_rope.py
similarity index 77%
rename from python/models/deepseek_v3_2/ops/qkv_rope.py
rename to tilert/models/deepseek_v3_2/ops/qkv_rope.py
index 7a9a55a..25203a4 100644
--- a/python/models/deepseek_v3_2/ops/qkv_rope.py
+++ b/tilert/models/deepseek_v3_2/ops/qkv_rope.py
@@ -1,17 +1,13 @@
-"""QKV Rope operation module.
-
-Unified for deepseek_v3_2 (n_local_heads=16) and glm_5 (n_local_heads=8).
-Dispatches by q_pe.shape[2]: 16 -> qkv_rope_op, 8 -> qkv_rope_glm5_op.
-"""
+"""QKV Rope operation module."""
from dataclasses import dataclass
+from enum import Enum
import torch
from tilert.models.base import TileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs
from tilert.models.utils import apply_rotary_emb
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -28,34 +24,30 @@ def qkv_rope(
rope_freqs: torch.Tensor,
cur_pos: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> None:
"""
Perform QKV Rope operation.
- Unified for deepseek_v3_2 (16 heads) and glm_5 (8 heads). Dispatches by
- pe_cache (q_pe) shape[2]: 16 -> qkv_rope_op, 8 -> qkv_rope_glm5_op.
-
Args:
pe_cache: Q PE tensor (bsz, seq, n_local_heads, qk_rope_head_dim).
kv_cache: K PE cache (bsz, seq, qk_rope_head_dim).
rope_freqs: Rope frequencies tensor.
cur_pos: Current position tensor.
profile_logs: Profile logs tensor.
+ model_arch: Model architecture string.
+ compute_kernel_type: Compute kernel type string.
"""
- n_local_heads = pe_cache.shape[2]
- qk_rope_head_dim = pe_cache.shape[3]
- if qk_rope_head_dim != 64:
- raise ValueError(f"Unsupported qk_rope_head_dim: {qk_rope_head_dim}")
-
- if n_local_heads == 16:
- torch.ops.tilert.qkv_rope_op(pe_cache, kv_cache, rope_freqs, cur_pos, profile_logs)
- elif n_local_heads == 8:
- torch.ops.tilert.qkv_rope_glm5_op(pe_cache, kv_cache, rope_freqs, cur_pos, profile_logs)
- else:
- raise ValueError(
- f"Unsupported n_local_heads: {n_local_heads}. "
- "QKVRoPE supports n_local_heads=16 (deepseek_v3_2) or 8 (glm_5)."
- )
+ torch.ops.tilert.qkv_rope_op(
+ pe_cache,
+ kv_cache,
+ rope_freqs,
+ cur_pos,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
@dataclass
@@ -82,9 +74,20 @@ def __call__(self) -> list[str]:
return self.tilert_tensor_alias
+class QKVRoPEAlgorithm(Enum):
+ """QKVRoPE algorithm."""
+
+ GENERAL = "general"
+
+
class QKVRoPE(TileRTModule):
"""QKV RoPE module. Unified for deepseek_v3_2 and glm_5."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [QKVRoPEAlgorithm.GENERAL],
+ "glm_5": [QKVRoPEAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -165,12 +168,13 @@ def tilert_forward(
cur_pos = torch.tensor([start_pos], dtype=torch.int32)
qkv_rope(
- q_pe_rope, pe_cache[:bsz, start_pos:end_pos], rope_freqs, cur_pos, self.profile_logs
+ q_pe_rope,
+ pe_cache[:bsz, start_pos:end_pos],
+ rope_freqs,
+ cur_pos,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return q_pe_rope
diff --git a/tilert/models/deepseek_v3_2/ops/receive_selected_token_ids.py b/tilert/models/deepseek_v3_2/ops/receive_selected_token_ids.py
new file mode 100644
index 0000000..508d13e
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/receive_selected_token_ids.py
@@ -0,0 +1,35 @@
+"""ReceiveSelectedTokenIds — receive idx_selects from GPU 0."""
+
+import torch
+
+__all__ = [
+ "receive_selected_token_ids",
+]
+
+
+def receive_selected_token_ids(
+ ll_buf: torch.Tensor,
+ dst: torch.Tensor,
+ expected_flag: int,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """Receive idx_selects from GPU 0.
+
+ Args:
+ ll_buf: Receive buffer on this GPU (written by GPU 0).
+ dst: Destination idx_selects tensor [1, S, 2048] int32.
+ expected_flag: Expected synchronization flag value.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.receive_selected_token_ids_op(
+ ll_buf,
+ dst,
+ expected_flag,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py
similarity index 94%
rename from python/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py
rename to tilert/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py
index ce867a7..a12441a 100644
--- a/python/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py
@@ -1,6 +1,7 @@
"""RMSNormExpertProj operation module."""
from dataclasses import dataclass
+from enum import Enum
import torch
from torch import nn
@@ -8,7 +9,6 @@
from tilert.models.base import TileRTModule
from tilert.models.common import RMSNorm, init_func, linear
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -40,9 +40,20 @@ def __call__(self) -> list[str]:
return [self.unproj_o_gamma, self.exp_proj_weights]
+class RMSNormExpertProjAlgorithm(Enum):
+ """RMSNormExpertProj algorithm."""
+
+ GENERAL = "general"
+
+
class RMSNormExpertProj(TileRTModule):
"""RMS Norm followed by expert projection."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormExpertProjAlgorithm.GENERAL],
+ "glm_5": [RMSNormExpertProjAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -151,12 +162,10 @@ def tilert_forward(self, x_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
self.tilert_proj_weight,
scores_out,
hidden_out,
+ self.model_args.arch_name,
+ "bf16",
self.profile_logs,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return hidden_out, scores_out
def __call__(self, x_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
diff --git a/tilert/models/deepseek_v3_2/ops/rmsnorm_head_proj.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_head_proj.py
new file mode 100644
index 0000000..413a2e7
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_head_proj.py
@@ -0,0 +1,296 @@
+"""RMSNormHeadProj operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "rmsnorm_head_proj",
+ "RMSNormHeadProj",
+ "RMSNormHeadProjTilertWeightsAlias",
+]
+
+
+def rmsnorm_head_proj(
+ hidden_in: torch.Tensor,
+ gamma_in: torch.Tensor,
+ weight_in: torch.Tensor,
+ hidden_rmsnorm_out: torch.Tensor,
+ logits_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+) -> None:
+ """RMS Norm Head Projection operation."""
+ torch.ops.tilert.rmsnorm_head_proj_op(
+ hidden_in,
+ gamma_in,
+ weight_in,
+ hidden_rmsnorm_out,
+ logits_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+
+
+class RMSNormHeadProjAlgorithm(Enum):
+ """RMSNormHeadProj algorithm"""
+
+ GENERAL = "general"
+
+
+class RMSNormHeadProjWeightsConverter(TilertWeightsConverter):
+ """RMSNormHeadProj weights converter"""
+
+ @staticmethod
+ def tilert_to_tilert_native_bf16_warp_gemv(
+ tilert_weight_in: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert TILERT weights to TILERT native bf16 warp gemv weights."""
+ weights = tilert_weight_in.reshape(1010, 16, 7, 1024)
+ weights = weights.transpose(1, 2).reshape(7070, 16, 1024)
+ return weights.contiguous()
+
+ def convert_to_general(
+ self, weights_list: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert the weights to general format.
+
+ Args:
+ weights_list: List of weights.
+
+ Returns:
+ Tuple of weights.
+ """
+ args = self.model_args
+ assert args.arch_name == "deepseek_v3_2" or args.arch_name == "glm_5"
+
+ with torch.inference_mode():
+ rmsnorm_gamma, mat_in = weights_list
+ logits_dim = mat_in.shape[-2]
+ dim = mat_in.shape[-1]
+ num_steps = dim // 1024
+ assert dim % 1024 == 0
+ weights = mat_in.reshape(logits_dim // 16, 16, num_steps, 1024)
+ weights = weights.transpose(1, 2).reshape(logits_dim // 16 * num_steps, 16, 1024)
+ return rmsnorm_gamma.float(), weights
+
+
+@dataclass
+class RMSNormHeadProjTilertWeightsAlias:
+ """TileRT weights alias for RMSNormHeadProj."""
+
+ model_norm_weight = "model.norm.weight"
+ lm_head_weight = "lm_head.weight"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.model_norm_weight, self.lm_head_weight]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RMSNormHeadProj(TileRTModule):
+ """RMSNormHeadProj module"""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormHeadProjAlgorithm.GENERAL],
+ "glm_5": [RMSNormHeadProjAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ algorithm: RMSNormHeadProjAlgorithm = RMSNormHeadProjAlgorithm.GENERAL,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+ self.logits_dim = self.model_args.vocab_size
+ self.algorithm = algorithm
+ self.eps = self.model_args.eps
+
+ self.ref_rmsnorm_gamma: torch.Tensor | None = None
+ self.ref_head_proj: torch.Tensor | None = None
+
+ self.tilert_rmsnorm_gamma: torch.Tensor | None = None
+ self.tilert_head_proj: torch.Tensor | None = None
+
+ self.hidden_rmsnorm_out: torch.Tensor | None = None
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.tilert_weights_alias = RMSNormHeadProjTilertWeightsAlias()
+
+ self.ref_tensor_alias: list[str] = [
+ "model.norm.weight",
+ "lm_head.weight",
+ ]
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias()
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """
+ Get the weights list.
+
+ Returns:
+ List of weights.
+ """
+ return [self.tilert_rmsnorm_gamma, self.tilert_head_proj]
+
+ def device_sharding(
+ self,
+ weights_dict: dict[str, torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Device sharding.
+
+ Args:
+ weights_dict: Dictionary of weights.
+ key_prefix: Key prefix.
+ Returns:
+ Tuple of weights.
+ """
+ rmsnorm_gamma_key = "model.norm.weight"
+ head_proj_key = "lm_head.weight"
+ rmsnorm_gamma = weights_dict[rmsnorm_gamma_key][None, ...]
+ rmsnorm_gamma = rmsnorm_gamma.repeat(self.num_devices, 1)
+ head_proj = weights_dict[head_proj_key]
+
+ head_proj = head_proj.reshape(self.num_devices, -1, self.dim)
+ return rmsnorm_gamma.contiguous(), head_proj.contiguous()
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Initialize the reference weights.
+
+ Args:
+ state_dict: State dictionary.
+ device_id: Device ID.
+ """
+ sharded_list = self.device_sharding(state_dict)
+
+ gamma, head_proj = sharded_list[0][self.device_id], sharded_list[1][self.device_id]
+ self.ref_rmsnorm_gamma = gamma
+ self.ref_head_proj = head_proj
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Initialize the tilert weights.
+
+ Args:
+ state_dict: State dictionary.
+ """
+ assert self.algorithm is not None
+ self.tilert_rmsnorm_gamma, self.tilert_head_proj = RMSNormHeadProjWeightsConverter(
+ self.model_args, self.num_devices
+ ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tilert_weights_alias()])
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """
+ Initialize the tilert variables.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_rmsnorm_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ )
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.logits_dim // self.num_devices),
+ dtype=torch.float32,
+ device=f"cuda:{self.device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{self.device_id}")
+ self.is_init = True
+
+ def init_random_weights(self, device_id: int | None = None) -> None:
+ """Initialize the random weights."""
+ if device_id is None:
+ device_id = self.device_id
+ rmsnorm_gamma = torch.randn(self.dim, dtype=torch.float32, device=f"cuda:{device_id}")
+ head_proj = torch.randn(
+ self.logits_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{device_id}"
+ )
+
+ tensor_list = [
+ rmsnorm_gamma,
+ head_proj,
+ ]
+ state_dict = dict(zip(self.ref_tensor_alias, tensor_list))
+
+ self.init_reference_weights(state_dict)
+ sharded_list = self.device_sharding(state_dict)
+ sharded_state_dict = {
+ alias: sharded_list[i][self.device_id]
+ for i, alias in enumerate(self.tilert_weights_alias())
+ }
+ self.init_tilert_weights(sharded_state_dict)
+
+ def golden_forward(
+ self,
+ hidden_in: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Forward pass for the down-project module.
+
+ Args:
+ hidden_in: Input hidden.
+
+ Returns:
+ Output tensor.
+ """
+ assert self.ref_rmsnorm_gamma is not None
+ assert self.ref_head_proj is not None
+ bsz = hidden_in.shape[0]
+ assert bsz == 1
+ hidden_rmsnorm = torch.nn.functional.rms_norm(
+ hidden_in.float(), [hidden_in.size(-1)], self.ref_rmsnorm_gamma, self.eps
+ )
+ return hidden_rmsnorm.float() @ self.ref_head_proj.T.float()
+
+ def tilert_forward(
+ self,
+ hidden_in: torch.Tensor,
+ ) -> torch.Tensor:
+ assert self.hidden_out is not None
+
+ rmsnorm_head_proj(
+ hidden_in,
+ self.tilert_rmsnorm_gamma,
+ self.tilert_head_proj,
+ self.hidden_rmsnorm_out,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ hidden_in: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(hidden_in)
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_kv.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_kv.py
similarity index 88%
rename from python/models/deepseek_v3_2/ops/rmsnorm_kv.py
rename to tilert/models/deepseek_v3_2/ops/rmsnorm_kv.py
index d9c9af0..fcc3464 100644
--- a/python/models/deepseek_v3_2/ops/rmsnorm_kv.py
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_kv.py
@@ -1,12 +1,12 @@
"""RMSNormKV operation module."""
from dataclasses import dataclass
+from enum import Enum
import torch
from tilert.models.base import TileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -23,6 +23,8 @@ def rmsnorm_kv(
cur_pos: torch.Tensor,
kv_cache: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> None:
"""
Define the RMSNormKV operation.
@@ -33,8 +35,12 @@ def rmsnorm_kv(
cur_pos: Current position tensor.
kv_cache: Output tensor.
profile_logs: Profile logs tensor.
+ model_arch: Model architecture string.
+ compute_kernel_type: Compute kernel type string.
"""
- torch.ops.tilert.rmsnorm_kv_op(kv, gamma, cur_pos, kv_cache, profile_logs)
+ torch.ops.tilert.rmsnorm_kv_op(
+ kv, gamma, cur_pos, kv_cache, model_arch, compute_kernel_type, profile_logs
+ )
@dataclass
@@ -65,9 +71,20 @@ def __call__(self) -> list[str]:
return self.tilert_tensor_alias
+class KVRMSNormAlgorithm(Enum):
+ """KVRMSNorm algorithm."""
+
+ GENERAL = "general"
+
+
class KVRMSNorm(TileRTModule):
"""KVRMSNorm module: RMSNorm on KV tensor with in-place write to kv_cache."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [KVRMSNormAlgorithm.GENERAL],
+ "glm_5": [KVRMSNormAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -170,11 +187,14 @@ def tilert_forward(
assert self.tilert_kv_norm_weight is not None
assert self.profile_logs is not None
cur_pos = torch.tensor([start_pos], dtype=torch.int32, device=kv.device)
- rmsnorm_kv(kv, self.tilert_kv_norm_weight, cur_pos, kv_cache[:bsz], self.profile_logs)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
+ rmsnorm_kv(
+ kv,
+ self.tilert_kv_norm_weight,
+ cur_pos,
+ kv_cache[:bsz],
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
def __call__(
self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int
diff --git a/tilert/models/deepseek_v3_2/ops/rmsnorm_projq_wqb.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_projq_wqb.py
new file mode 100644
index 0000000..496df47
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_projq_wqb.py
@@ -0,0 +1,540 @@
+"""RmsnormProjqWqb operation module."""
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RmsnormProjqWqb",
+ "RmsnormProjqWqbAlgorithm",
+ "RmsnormProjqWqbWeightsConverter",
+]
+
+
+def rmsnorm_projq_wqb_op(
+ q: torch.Tensor,
+ wq_b: torch.Tensor,
+ wq_b_scales: torch.Tensor,
+ q_norm_weight: torch.Tensor,
+ q_nope: torch.Tensor,
+ q_pe: torch.Tensor,
+ profile_logs: torch.Tensor,
+ algorithm: str,
+ model_arch: str,
+) -> None:
+ torch.ops.tilert.rmsnorm_proj_qb_op(
+ q,
+ wq_b,
+ wq_b_scales,
+ q_norm_weight,
+ q_nope,
+ q_pe,
+ model_arch,
+ algorithm,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=q.device),
+ )
+
+
+class RmsnormProjqWqbAlgorithm(Enum):
+ """RmsnormProjqWqb algorithm."""
+
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+class RmsnormProjqWqbWeightsConverter(TilertWeightsConverter):
+ """Weights converter for RmsnormProjqWqb.
+
+ Supports configurations where n_heads is not evenly divisible by
+ num_devices; in that case n_local_heads is padded and padded head
+ weight rows are zero-filled.
+ """
+
+ kBf16NumCtas = 80
+ kGemvPageSize = 8
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args=model_args, num_devices=num_devices)
+
+ self.proc_groups = 8
+ self.repeat = 16
+
+ self.block_size = self.model_args.block_size
+
+ self.qk_nope_head_dim = self.model_args.qk_nope_head_dim
+ self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+
+ self.needs_padding = self.model_args.n_heads % num_devices != 0
+ self.n_local_heads = self._compute_n_local_heads(
+ self.model_args.n_heads, num_devices, self.qk_head_dim
+ )
+
+ self.q_lora_dim = self.model_args.q_lora_rank
+ self.q_lora_qdim = self.q_lora_dim // self.block_size
+
+ self.qk_dim = self.qk_head_dim * self.n_local_heads
+ self.qk_qdim = self.qk_dim // self.block_size
+
+ assert self.qk_dim % (self.kBf16NumCtas * self.kGemvPageSize) == 0, (
+ f"qk_dim ({self.qk_dim}) must be divisible by "
+ f"kBf16NumCtas * kGemvPageSize ({self.kBf16NumCtas * self.kGemvPageSize})"
+ )
+ assert self.qk_dim % self.block_size == 0, (
+ f"qk_dim ({self.qk_dim}) must be divisible by block_size ({self.block_size}) "
+ f"for scale alignment"
+ )
+
+ @classmethod
+ def _compute_n_local_heads(cls, n_total_heads: int, num_devices: int, qk_head_dim: int) -> int:
+ """Compute padded n_local_heads per device."""
+ if n_total_heads % num_devices == 0:
+ return n_total_heads // num_devices
+
+ base = math.ceil(n_total_heads / num_devices)
+ align_unit = cls.kBf16NumCtas * cls.kGemvPageSize
+ g = math.gcd(qk_head_dim, align_unit)
+ head_align = align_unit // g
+ return math.ceil(base / head_align) * head_align
+
+ @staticmethod
+ def _redistribute_heads(
+ wq_b_full: torch.Tensor,
+ wq_b_scale_full: torch.Tensor,
+ n_total_heads: int,
+ n_local_heads: int,
+ num_devices: int,
+ qk_head_dim: int,
+ block_size: int,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """Redistribute heads across devices with padding.
+
+ Args:
+ wq_b_full: [n_total_heads * qk_head_dim, q_lora_dim] full weight.
+ wq_b_scale_full: [n_total_heads * qk_head_dim // block_size, q_lora_qdim] full scale.
+ n_total_heads: Total number of heads (e.g. 128).
+ n_local_heads: Target heads per GPU (padded, e.g. 20).
+ num_devices: Number of devices (e.g. 7).
+ qk_head_dim: Head dimension (e.g. 192).
+ block_size: Quantization block size (e.g. 128).
+
+ Returns:
+ Lists of per-device (wq_b, wq_b_scale) with shape
+ [n_local_heads * qk_head_dim, q_lora_dim] and
+ [n_local_heads * qk_head_dim // block_size, q_lora_qdim].
+ """
+ total_rows = n_total_heads * qk_head_dim
+ rows_per_dev = n_local_heads * qk_head_dim
+ scale_rows_per_dev = rows_per_dev // block_size
+ total_scale_rows = total_rows // block_size
+
+ q_lora_dim = wq_b_full.shape[-1]
+ q_lora_qdim = wq_b_scale_full.shape[-1]
+
+ assert rows_per_dev % block_size == 0, (
+ f"n_local_heads * qk_head_dim ({rows_per_dev}) must be "
+ f"divisible by block_size ({block_size})"
+ )
+
+ wq_b_list = []
+ scale_list = []
+ for dev in range(num_devices):
+ start_row = dev * rows_per_dev
+ end_row = min(total_rows, start_row + rows_per_dev)
+ real_rows = max(0, end_row - start_row)
+
+ dev_wqb = torch.zeros(
+ rows_per_dev, q_lora_dim, dtype=wq_b_full.dtype, device=wq_b_full.device
+ )
+ if real_rows > 0:
+ dev_wqb[:real_rows] = wq_b_full[start_row:end_row]
+
+ start_scale = dev * scale_rows_per_dev
+ end_scale = min(total_scale_rows, start_scale + scale_rows_per_dev)
+ real_scale_rows = max(0, end_scale - start_scale)
+
+ dev_scale = torch.zeros(
+ scale_rows_per_dev,
+ q_lora_qdim,
+ dtype=wq_b_scale_full.dtype,
+ device=wq_b_scale_full.device,
+ )
+ if real_scale_rows > 0:
+ dev_scale[:real_scale_rows] = wq_b_scale_full[start_scale:end_scale]
+
+ wq_b_list.append(dev_wqb)
+ scale_list.append(dev_scale)
+
+ return wq_b_list, scale_list
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(
+ mat_in: torch.Tensor, q_lora_dim: int, pages: int
+ ) -> torch.Tensor:
+ """Swizzle 16xK matrix for paged MMA layout, any K divisible by 16."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == q_lora_dim
+ k_per_page = q_lora_dim // pages
+ n_k_tiles = k_per_page // 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = RmsnormProjqWqbWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def _common_to_tilert_fp16mma(
+ self,
+ wq_b: torch.Tensor,
+ wq_b_scales_raw: torch.Tensor,
+ rmsnorm_gamma: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common weights to the FP16 MMA layout."""
+ pages = 2
+ rows_per_cta = 32
+
+ qk_nope_dim = self.n_local_heads * self.qk_nope_head_dim
+ qk_pe_dim = self.n_local_heads * self.qk_rope_head_dim
+ nope_ctas = qk_nope_dim // rows_per_cta
+ pe_ctas = qk_pe_dim // rows_per_cta
+ num_ctas = nope_ctas + pe_ctas
+
+ wq_b_scales_f32 = wq_b_scales_raw.to(torch.float32)
+ wq_b_scales_f32 = (
+ wq_b_scales_f32.reshape(self.qk_qdim, 1, self.q_lora_qdim)
+ .repeat(1, self.block_size, 1)
+ .reshape(self.qk_dim, self.q_lora_qdim)
+ )
+
+ wq_b_scales_f32 = wq_b_scales_f32.reshape(
+ self.n_local_heads, self.qk_head_dim, self.q_lora_qdim
+ )
+ scale_nope = wq_b_scales_f32[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_qdim)
+ scale_pe = wq_b_scales_f32[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_qdim)
+
+ scale_nope = scale_nope.reshape(
+ nope_ctas, rows_per_cta, pages, self.q_lora_qdim // pages
+ ).transpose(1, 2)[:, :, 0, :]
+ scale_pe = scale_pe.reshape(
+ pe_ctas, rows_per_cta, pages, self.q_lora_qdim // pages
+ ).transpose(1, 2)[:, :, 0, :]
+
+ scales = torch.cat([scale_nope, scale_pe], dim=0)
+ scales_fp8 = scales.contiguous().view(torch.float8_e4m3fn)
+
+ wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim)
+ wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_dim)
+ wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_dim)
+
+ wq_b_nope = wq_b_nope.reshape(nope_ctas, rows_per_cta // 16, 16, self.q_lora_dim)
+ wq_b_nope = RmsnormProjqWqbWeightsConverter._swizzle_mma_16x16_for_pages(
+ wq_b_nope, self.q_lora_dim, pages
+ )
+ wq_b_nope = (
+ wq_b_nope.reshape(nope_ctas, rows_per_cta // 16, pages, 16, -1)
+ .transpose(1, 2)
+ .reshape(nope_ctas, pages, rows_per_cta, -1)
+ )
+
+ wq_b_pe = wq_b_pe.reshape(pe_ctas, rows_per_cta // 16, 16, self.q_lora_dim)
+ wq_b_pe = RmsnormProjqWqbWeightsConverter._swizzle_mma_16x16_for_pages(
+ wq_b_pe, self.q_lora_dim, pages
+ )
+ wq_b_pe = (
+ wq_b_pe.reshape(pe_ctas, rows_per_cta // 16, pages, 16, -1)
+ .transpose(1, 2)
+ .reshape(pe_ctas, pages, rows_per_cta, -1)
+ )
+
+ weights = torch.cat([wq_b_nope, wq_b_pe], dim=0)
+ weights = weights.reshape(num_ctas, pages, -1)
+
+ scale_padding_size = 128 - scales_fp8.shape[-1]
+ scale_padding = torch.zeros(
+ num_ctas,
+ pages,
+ scale_padding_size,
+ dtype=torch.float8_e4m3fn,
+ device=wq_b.device,
+ )
+ tilert_wqb = torch.cat([weights, scales_fp8, scale_padding], dim=-1).contiguous()
+
+ tilert_wqb_scales = torch.zeros(1, dtype=torch.bfloat16)
+ tilert_gamma = rmsnorm_gamma.float().detach().clone()
+ return tilert_wqb, tilert_wqb_scales, tilert_gamma
+
+ def convert_to_bf16mma(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common-format weights to the BF16 MMA layout."""
+ return self.convert_to_fp16mma(weights)
+
+ def convert_to_fp16mma(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common-format weights to TileRT FP16 MMA layout."""
+ with torch.inference_mode():
+ wq_b, wq_b_scale, q_norm_weight = weights
+ return self._common_to_tilert_fp16mma(wq_b, wq_b_scale, q_norm_weight)
+
+
+@dataclass
+class RmsnormProjqWqbRefWeightsAlias:
+ """Reference weights alias for RmsnormProjqWqb."""
+
+ rmsnorm_gamma = "self_attn.q_a_layernorm.weight"
+ wqb_weights = "self_attn.q_b_proj.weight"
+ wqb_scales = "self_attn.q_b_proj.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.rmsnorm_gamma,
+ self.wqb_weights,
+ self.wqb_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class RmsnormProjqWqbTilertWeightsAlias:
+ """TileRT weights alias for RmsnormProjqWqb."""
+
+ rmsnorm_gamma = "q_rmsnorm_gamma"
+ wqb_weights = "wqb_weights"
+ wqb_scales = "wqb_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.rmsnorm_gamma,
+ self.wqb_weights,
+ self.wqb_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RmsnormProjqWqb(TileRTModule):
+ """RmsnormProjqWqb module: RMSNorm + Q projection (wq_b only)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ RmsnormProjqWqbAlgorithm.FP16MMA,
+ RmsnormProjqWqbAlgorithm.BF16MMA,
+ ],
+ "glm_5": [RmsnormProjqWqbAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int = 7,
+ ref_weights_alias: RmsnormProjqWqbRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+
+ self.tilert_weights_alias = RmsnormProjqWqbTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else RmsnormProjqWqbRefWeightsAlias()
+ )
+
+ self.n_local_heads = RmsnormProjqWqbWeightsConverter._compute_n_local_heads(
+ model_args.n_heads,
+ num_devices,
+ model_args.qk_nope_head_dim + model_args.qk_rope_head_dim,
+ )
+ self.q_lora_rank = model_args.q_lora_rank
+ self.n_heads = model_args.n_heads
+ self.qk_nope_head_dim = model_args.qk_nope_head_dim
+ self.qk_rope_head_dim = model_args.qk_rope_head_dim
+ self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+ self.qk_local_dim = self.qk_head_dim * self.n_local_heads
+
+ self.block_size = model_args.block_size
+ self.q_lora_qdim = self.q_lora_rank // self.block_size
+ self.qk_local_qdim = self.qk_local_dim // self.block_size
+ self.eps = model_args.eps
+
+ self.ref_q_norm: torch.Tensor | None = None
+ self.ref_wq_b: torch.Tensor | None = None
+
+ self.tilert_wq_b: torch.Tensor | None = None
+ self.tilert_wq_b_scales: torch.Tensor | None = None
+ self.tilert_q_norm_weight: torch.Tensor | None = None
+
+ self.q_nope: torch.Tensor | None = None
+ self.q_pe: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_q_norm_weight, self.tilert_wq_b, self.tilert_wq_b_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Redistribute heads across devices with padding."""
+ gamma = weights_map[self.ref_weights_alias.rmsnorm_gamma][None, ...].repeat(
+ self.num_devices, 1
+ )
+
+ wq_b_full = weights_map[self.ref_weights_alias.wqb_weights]
+ wq_b_scale_full = weights_map[self.ref_weights_alias.wqb_scales]
+
+ wq_b_list, scale_list = RmsnormProjqWqbWeightsConverter._redistribute_heads(
+ wq_b_full,
+ wq_b_scale_full,
+ n_total_heads=self.n_heads,
+ n_local_heads=self.n_local_heads,
+ num_devices=self.num_devices,
+ qk_head_dim=self.qk_head_dim,
+ block_size=self.block_size,
+ )
+
+ sharded_wqb_weights = torch.stack(wq_b_list, dim=0)
+ sharded_wqb_scales = torch.stack(scale_list, dim=0)
+
+ return {
+ self.tilert_weights_alias.rmsnorm_gamma: gamma,
+ self.tilert_weights_alias.wqb_weights: sharded_wqb_weights,
+ self.tilert_weights_alias.wqb_scales: sharded_wqb_scales,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize reference weights from common-format state dict."""
+ self.ref_q_norm = state_dict[self.ref_weights_alias.rmsnorm_gamma]
+
+ wq_b_full = state_dict[self.ref_weights_alias.wqb_weights]
+ wq_b_scale_full = state_dict[self.ref_weights_alias.wqb_scales]
+
+ wq_b_bf16_full = weight_dequant(wq_b_full, wq_b_scale_full)
+
+ total_rows = self.n_heads * self.qk_head_dim
+ rows_per_dev = self.n_local_heads * self.qk_head_dim
+ start_row = self.device_id * rows_per_dev
+ end_row = min(total_rows, start_row + rows_per_dev)
+ real_rows = max(0, end_row - start_row)
+
+ dev_wqb = torch.zeros(
+ rows_per_dev,
+ wq_b_bf16_full.shape[-1],
+ dtype=wq_b_bf16_full.dtype,
+ device=wq_b_bf16_full.device,
+ )
+ if real_rows > 0:
+ dev_wqb[:real_rows] = wq_b_bf16_full[start_row:end_row]
+
+ self.ref_wq_b = dev_wqb.contiguous()
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize TileRT weights from common-format state dict."""
+ weights = [
+ state_dict[self.tilert_weights_alias.wqb_weights],
+ state_dict[self.tilert_weights_alias.wqb_scales],
+ state_dict[self.tilert_weights_alias.rmsnorm_gamma],
+ ]
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_wq_b, self.tilert_wq_b_scales, self.tilert_q_norm_weight = (
+ RmsnormProjqWqbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ self.algorithm, weights
+ )
+ )
+
+ def init_random_weights(self) -> None:
+ """Initialize random reference and TileRT weights for testing."""
+ q_norm = torch.randn(self.q_lora_rank, dtype=torch.float32)
+
+ wq_b = torch.randn(self.qk_local_dim, self.q_lora_rank, dtype=torch.bfloat16).to(
+ torch.float8_e4m3fn
+ )
+ scale_dtype = torch.float32 if self.model_args.arch_name == "glm_5" else torch.bfloat16
+ wq_b_scale = torch.randn(self.qk_local_qdim, self.q_lora_qdim, dtype=scale_dtype)
+
+ self.ref_q_norm = q_norm
+ self.ref_wq_b = weight_dequant(wq_b, wq_b_scale).contiguous()
+
+ assert self.algorithm is not None, "Algorithm is not set"
+ weights = [wq_b, wq_b_scale, q_norm]
+ self.tilert_wq_b, self.tilert_wq_b_scales, self.tilert_q_norm_weight = (
+ RmsnormProjqWqbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ self.algorithm, weights
+ )
+ )
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate TileRT output buffers."""
+ self.q_nope = torch.zeros(
+ batch_size, seq_len, self.n_local_heads, self.qk_nope_head_dim, dtype=torch.bfloat16
+ )
+ self.q_pe = torch.zeros(
+ batch_size, seq_len, self.n_local_heads, self.qk_rope_head_dim, dtype=torch.bfloat16
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Reference forward: RMSNorm + linear projection (no iq)."""
+ assert self.ref_q_norm is not None
+ assert self.ref_wq_b is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ qr = torch.nn.functional.rms_norm(q.float(), [q.size(-1)], self.ref_q_norm, self.eps).to(
+ q.dtype
+ )
+
+ q_out = torch.matmul(qr, self.ref_wq_b.T)
+ q_out = q_out.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
+ q_nope, q_pe = torch.split(q_out, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ return q_nope, q_pe
+
+ def tilert_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ assert self.tilert_wq_b is not None
+ assert self.tilert_wq_b_scales is not None
+ assert self.tilert_q_norm_weight is not None
+ assert self.q_nope is not None
+ assert self.q_pe is not None
+ assert self.profile_logs is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ assert self.algorithm is not None, "Algorithm is not set"
+
+ rmsnorm_projq_wqb_op(
+ q,
+ self.tilert_wq_b,
+ self.tilert_wq_b_scales,
+ self.tilert_q_norm_weight,
+ self.q_nope,
+ self.q_pe,
+ self.profile_logs,
+ self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return self.q_nope, self.q_pe
diff --git a/tilert/models/deepseek_v3_2/ops/rmsnorm_projq_wqi.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_projq_wqi.py
new file mode 100644
index 0000000..b91faf6
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_projq_wqi.py
@@ -0,0 +1,340 @@
+"""RmsnormProjqWqi operation module (IQ-only projection)."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+from einops import rearrange
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RmsnormProjqWqi",
+ "RmsnormProjqWqiAlgorithm",
+ "RmsnormProjqWqiWeightsConverter",
+]
+
+
+def rmsnorm_projq_wqi_op(
+ q: torch.Tensor,
+ wqi: torch.Tensor,
+ wqi_scale: torch.Tensor,
+ rmsnorm_gamma: torch.Tensor,
+ iq: torch.Tensor,
+ profile_logs: torch.Tensor,
+ algorithm: str,
+ model_arch: str,
+) -> None:
+ torch.ops.tilert.rmsnorm_proj_qi_op(
+ q,
+ wqi,
+ wqi_scale,
+ rmsnorm_gamma,
+ iq,
+ model_arch,
+ algorithm,
+ profile_logs,
+ )
+
+
+class RmsnormProjqWqiAlgorithm(Enum):
+ """RmsnormProjqWqi algorithm."""
+
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+class RmsnormProjqWqiWeightsConverter(TilertWeightsConverter):
+ """Weights converter: common format to TileRT format (IQ only)."""
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args=model_args, num_devices=num_devices)
+
+ self.block_size = self.model_args.block_size
+ self.q_lora_dim = self.model_args.q_lora_rank
+ self.q_lora_qdim = self.q_lora_dim // self.block_size
+
+ self.index_n_heads = self.model_args.index_n_heads
+ self.index_head_dim = self.index_n_heads * self.model_args.index_head_dim
+ self.index_head_qdim = self.index_head_dim // self.block_size
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(
+ mat_in: torch.Tensor, q_lora_rank: int, pages: int
+ ) -> torch.Tensor:
+ """Swizzle 16xK matrix for paged MMA layout, any K divisible by 16."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == q_lora_rank
+ pre_shape = mat_in.shape[:-2]
+ k_per_page = q_lora_rank // pages
+ n_k_tiles = k_per_page // 16
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = RmsnormProjqWqiWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def _common_to_tilert_fp16mma(
+ self,
+ wqi: torch.Tensor,
+ wqi_scales: torch.Tensor,
+ rmsnorm_gamma: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common weights to TileRT FP16 MMA layout (IQ only)."""
+ sms = 128
+ k_per_page = 1024 if self.model_args.arch_name == "glm_5" else 512
+ pages = self.q_lora_dim // k_per_page
+ iq_dim_per_sm = self.index_head_dim // sms
+
+ wqi_scales_f32 = wqi_scales.to(torch.float32)
+ wqi_scales_f32 = (
+ wqi_scales_f32.reshape(self.index_head_qdim, 1, self.q_lora_qdim)
+ .repeat(1, self.block_size, 1)
+ .reshape(self.index_head_dim, self.q_lora_qdim)
+ )
+ wqi_scales_f32 = wqi_scales_f32.reshape(
+ sms, iq_dim_per_sm, pages, self.q_lora_qdim // pages
+ ).transpose(1, 2)
+ wqi_scales_f32 = wqi_scales_f32[:, :, 0, :]
+ wqi_full_scales = wqi_scales_f32.contiguous().view(torch.float8_e4m3fn)
+
+ wqi_mat = wqi.reshape(sms, iq_dim_per_sm // 16, 16, self.q_lora_dim)
+ wqi_mat = RmsnormProjqWqiWeightsConverter._swizzle_mma_16x16_for_pages(
+ wqi_mat, self.q_lora_dim, pages
+ )
+ wqi_mat = (
+ wqi_mat.reshape(sms, iq_dim_per_sm // 16, pages, 16, -1)
+ .transpose(1, 2)
+ .reshape(sms, pages, iq_dim_per_sm, -1)
+ )
+ wqi_mat = wqi_mat.reshape(sms, pages, -1)
+
+ wqi_scales_padding = torch.zeros(
+ sms,
+ pages,
+ 128 - wqi_full_scales.shape[-1],
+ dtype=torch.float8_e4m3fn,
+ device=wqi.device,
+ )
+ tilert_wqi = torch.cat([wqi_mat, wqi_full_scales, wqi_scales_padding], dim=-1).contiguous()
+ tilert_wqi_scales = torch.zeros(1, dtype=torch.bfloat16)
+ tilert_gamma = rmsnorm_gamma.float().detach().clone()
+ return tilert_wqi, tilert_wqi_scales, tilert_gamma
+
+ def convert_to_bf16mma(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common-format weights to the BF16 MMA layout."""
+ return self.convert_to_fp16mma(weights)
+
+ def convert_to_fp16mma(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common-format weights to TileRT FP16 MMA layout.
+
+ Args:
+ weights: [wqi, wqi_scale, q_norm_weight].
+ """
+ with torch.inference_mode():
+ wqi, wqi_scale, q_norm_weight = weights
+ return self._common_to_tilert_fp16mma(wqi, wqi_scale, q_norm_weight)
+
+
+@dataclass
+class RmsnormProjqWqiRefWeightsAlias:
+ """Reference (HuggingFace) weights alias for RmsnormProjqWqi."""
+
+ rmsnorm_gamma = "self_attn.q_a_layernorm.weight"
+ wqi_weights = "self_attn.indexer.wq_b.weight"
+ wqi_scales = "self_attn.indexer.wq_b.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.rmsnorm_gamma, self.wqi_weights, self.wqi_scales]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class RmsnormProjqWqiTilertWeightsAlias:
+ """TileRT weights alias for RmsnormProjqWqi."""
+
+ rmsnorm_gamma = "q_rmsnorm_gamma_qi"
+ wqi_weights = "wqi_weights"
+ wqi_scales = "wqi_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.rmsnorm_gamma, self.wqi_weights, self.wqi_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RmsnormProjqWqi(TileRTModule):
+ """RmsnormProjqWqi module: RMSNorm + W_qi projection (IQ only, GLM5 v2)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ RmsnormProjqWqiAlgorithm.FP16MMA,
+ RmsnormProjqWqiAlgorithm.BF16MMA,
+ ],
+ "glm_5": [RmsnormProjqWqiAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+
+ self.tilert_weights_alias = RmsnormProjqWqiTilertWeightsAlias()
+ self.ref_weights_alias = RmsnormProjqWqiRefWeightsAlias()
+
+ self.q_lora_rank = model_args.q_lora_rank
+ self.index_n_heads = model_args.index_n_heads
+ self.head_dim = model_args.index_head_dim
+ self.index_head_dim = model_args.index_n_heads * model_args.index_head_dim
+
+ self.block_size = model_args.block_size
+ self.q_lora_qdim = self.q_lora_rank // self.block_size
+ self.index_head_qdim = self.index_head_dim // self.block_size
+ self.eps = model_args.eps
+
+ self.ref_q_norm: torch.Tensor | None = None
+ self.ref_wqi: torch.Tensor | None = None
+
+ self.tilert_wqi: torch.Tensor | None = None
+ self.tilert_wqi_scales: torch.Tensor | None = None
+ self.tilert_q_norm_weight: torch.Tensor | None = None
+
+ self.iq: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_q_norm_weight, self.tilert_wqi, self.tilert_wqi_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Replicate IQ weights across devices (no per-head redistribution needed)."""
+ gamma = (
+ weights_map[self.ref_weights_alias.rmsnorm_gamma][None, ...]
+ .float()
+ .repeat(self.num_devices, 1)
+ )
+ wqi_weights = weights_map[self.ref_weights_alias.wqi_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wqi_scales = weights_map[self.ref_weights_alias.wqi_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ return {
+ self.tilert_weights_alias.rmsnorm_gamma: gamma,
+ self.tilert_weights_alias.wqi_weights: wqi_weights,
+ self.tilert_weights_alias.wqi_scales: wqi_scales,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize reference weights from common-format state dict."""
+ self.ref_q_norm = state_dict[self.tilert_weights_alias.rmsnorm_gamma]
+ wqi = weight_dequant(
+ state_dict[self.tilert_weights_alias.wqi_weights],
+ state_dict[self.tilert_weights_alias.wqi_scales],
+ )
+ self.ref_wqi = wqi.contiguous()
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize TileRT weights from common-format state dict."""
+ weights = [
+ state_dict[self.tilert_weights_alias.wqi_weights],
+ state_dict[self.tilert_weights_alias.wqi_scales],
+ state_dict[self.tilert_weights_alias.rmsnorm_gamma],
+ ]
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_wqi, self.tilert_wqi_scales, self.tilert_q_norm_weight = (
+ RmsnormProjqWqiWeightsConverter(self.model_args, self.num_devices).dispatch(
+ self.algorithm, weights
+ )
+ )
+
+ def init_random_weights(self) -> None:
+ """Initialize random reference and TileRT weights for testing."""
+ q_norm = torch.randn(self.q_lora_rank, dtype=torch.float32)
+ wqi = torch.randn(self.index_head_dim, self.q_lora_rank, dtype=torch.bfloat16).to(
+ torch.float8_e4m3fn
+ )
+ scale_dtype = torch.float32 if self.model_args.arch_name == "glm_5" else torch.bfloat16
+ wqi_scale = torch.randn(self.index_head_qdim, self.q_lora_qdim, dtype=scale_dtype)
+
+ ref_state = {
+ self.tilert_weights_alias.rmsnorm_gamma: q_norm,
+ self.tilert_weights_alias.wqi_weights: wqi,
+ self.tilert_weights_alias.wqi_scales: wqi_scale,
+ }
+
+ self.init_reference_weights(ref_state)
+ self.init_tilert_weights(ref_state)
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate TileRT output buffers."""
+ self.iq = torch.zeros(
+ batch_size, seq_len, self.index_n_heads, self.head_dim, dtype=torch.bfloat16
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, q: torch.Tensor) -> torch.Tensor:
+ """Reference forward: RMSNorm + W_qi_b linear projection."""
+ assert self.ref_q_norm is not None
+ assert self.ref_wqi is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4, 8]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ qr = torch.nn.functional.rms_norm(q.float(), [q.size(-1)], self.ref_q_norm, self.eps).to(
+ q.dtype
+ )
+
+ return rearrange(torch.matmul(qr, self.ref_wqi.T), "b s (h d) -> b s h d", d=self.head_dim)
+
+ def tilert_forward(self, q: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_wqi is not None
+ assert self.tilert_wqi_scales is not None
+ assert self.tilert_q_norm_weight is not None
+ assert self.iq is not None
+ assert self.profile_logs is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4, 8]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ assert self.algorithm is not None, "Algorithm is not set"
+
+ rmsnorm_projq_wqi_op(
+ q,
+ self.tilert_wqi,
+ self.tilert_wqi_scales,
+ self.tilert_q_norm_weight,
+ self.iq,
+ self.profile_logs,
+ self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return self.iq
diff --git a/tilert/models/deepseek_v3_2/ops/rmsnorm_projx_wqakis.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_projx_wqakis.py
new file mode 100644
index 0000000..4fd5b12
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_projx_wqakis.py
@@ -0,0 +1,341 @@
+"""RMSNormProjxWqakis operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.deepseek_v3_2.ops.projx_wis import projx_wis
+from tilert.models.deepseek_v3_2.ops.projx_wqaki import (
+ ProjxWqakiWeightsConverter,
+ projx_wqaki,
+)
+from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RMSNormProjxWqakis",
+]
+
+
+class RMSNormProjxWqakisWeightsConverter(TilertWeightsConverter):
+ """Weight converter for RMSNormProjxWqakis (decoupled FP8 MMA)."""
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args, num_devices)
+
+ def convert_to_decoupled(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert weights to decoupled FP8 MMA format.
+
+ Args:
+ weights: [gamma, wq_a, wq_a_scale, wki, wki_scale, wis, wis_scale]
+
+ Returns:
+ (wqaki_packed, wis_bf16, gamma)
+ """
+ arch_name = self.model_args.arch_name
+ x_rmsnorm_gamma, wq_a, wq_a_scale, wki, wki_scale, wis, _wis_scale = weights
+
+ if arch_name == "deepseek_v3_2":
+ wqaki_packed = ProjxWqakiWeightsConverter.convert_dsv32(
+ wq_a, wq_a_scale, wki, wki_scale
+ )
+ elif arch_name == "glm_5":
+ wqaki_packed = ProjxWqakiWeightsConverter.convert_glm5_68cta(
+ wq_a, wq_a_scale, wki, wki_scale
+ )
+ else:
+ raise ValueError(f"Unsupported architecture: {arch_name}")
+
+ wis_bf16 = wis.to(torch.bfloat16)
+ return wqaki_packed, wis_bf16, x_rmsnorm_gamma.float()
+
+
+class RMSNormProjxWqakisRefWeightsAlias:
+ """Reference weight aliases for RMSNormProjxWqakis."""
+
+ x_rmsnorm_gamma = "input_layernorm.weight"
+ q_a_weights = "self_attn.q_a_proj.weight"
+ q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
+ wk_weights = "self_attn.indexer.wk.weight"
+ wk_scales = "self_attn.indexer.wk.weight_scale_inv"
+ wis_weights = "self_attn.indexer.weights_proj.weight"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.wk_weights,
+ self.wk_scales,
+ self.wis_weights,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+class RMSNormProjxWqakisTilertWeightsAlias:
+ """Tilert weight aliases for RMSNormProjxWqakis."""
+
+ x_rmsnorm_gamma = "x_rmsnorm_gamma"
+ q_a_weights = "q_a_weights"
+ q_a_scales = "q_a_scales"
+ wk_weights = "wk_weights"
+ wk_scales = "wk_scales"
+ wis_weights = "wis_weights"
+ wis_scales = "wis_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.wk_weights,
+ self.wk_scales,
+ self.wis_weights,
+ self.wis_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RMSNormProjxWqakisAlgorithm(Enum):
+ """RMSNormProjxWqakis algorithm."""
+
+ FP8MMA = "fp8mma"
+
+
+class RMSNormProjxWqakis(TileRTModule):
+ """Decoupled RMSNorm + GEMV(W_q_a, W_ki, W_is)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormProjxWqakisAlgorithm.FP8MMA],
+ "glm_5": [RMSNormProjxWqakisAlgorithm.FP8MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: RMSNormProjxWqakisRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = RMSNormProjxWqakisTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else RMSNormProjxWqakisRefWeightsAlias()
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+ self.q_lora_rank = self.model_args.q_lora_rank
+ self.idx_head_dim = self.model_args.index_head_dim
+ self.idx_score_dim = self.model_args.index_n_heads
+ self.block_size = self.model_args.block_size
+ self.eps = self.model_args.eps
+
+ self.ref_norm_gamma: torch.Tensor | None = None
+ self.ref_wq_a: torch.Tensor | None = None
+ self.ref_wki: torch.Tensor | None = None
+ self.ref_wis: torch.Tensor | None = None
+
+ self.tilert_norm_gamma: torch.Tensor | None = None
+ self.tilert_wqakis: torch.Tensor | None = None
+ self.tilert_wis: torch.Tensor | None = None
+
+ self.q_out: torch.Tensor | None = None
+ self.ki_out: torch.Tensor | None = None
+ self.idx_scores_out: torch.Tensor | None = None
+ self.x_rmsnorm_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_scale_out: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ if self.arch_name == "glm_5":
+ self.compute_kernel_type = "fp8mma_68cta"
+ else:
+ self.compute_kernel_type = "fp8mma"
+
+ self.tilert_tensor_alias: list[str] = [
+ "x_rmsnorm_gamma",
+ "qakis_weights",
+ "qakis_scales",
+ ]
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_norm_gamma, self.tilert_wqakis, self.tilert_wis]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat weights for device sharding."""
+ input_layernorm_weight = (
+ weights_map[self.ref_weights_alias.x_rmsnorm_gamma][None, ...]
+ .float()
+ .repeat(self.num_devices, 1)
+ )
+ q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wk_weight = weights_map[self.ref_weights_alias.wk_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wk_weight_scale = weights_map[self.ref_weights_alias.wk_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wis_weight = weights_map[self.ref_weights_alias.wis_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ is_n_rows = weights_map[self.ref_weights_alias.wis_weights].shape[0]
+ is_scale_rows = (is_n_rows + self.block_size - 1) // self.block_size
+ is_scale_cols = self.dim // self.block_size
+ wis_weight_scale = torch.ones(
+ self.num_devices, is_scale_rows, is_scale_cols, dtype=torch.bfloat16
+ )
+ return {
+ self.tilert_weights_alias.x_rmsnorm_gamma: input_layernorm_weight,
+ self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
+ self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
+ self.tilert_weights_alias.wk_weights: wk_weight,
+ self.tilert_weights_alias.wk_scales: wk_weight_scale,
+ self.tilert_weights_alias.wis_weights: wis_weight,
+ self.tilert_weights_alias.wis_scales: wis_weight_scale,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ aliases = self.ref_weights_alias()
+ self.ref_norm_gamma = state_dict[aliases[0]]
+ self.ref_wq_a = weight_dequant(state_dict[aliases[1]], state_dict[aliases[2]])
+ self.ref_wki = weight_dequant(state_dict[aliases[3]], state_dict[aliases[4]])
+ self.ref_wis = state_dict[aliases[5]].to(torch.bfloat16)
+
+ assert self.ref_norm_gamma.shape[-1] == self.dim
+ assert self.ref_wq_a.shape == (self.q_lora_rank, self.dim)
+ assert self.ref_wki.shape == (self.idx_head_dim, self.dim)
+ assert self.ref_wis.shape == (self.idx_score_dim, self.dim)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ tilert_aliases = self.tilert_weights_alias()
+ weights_list = [state_dict[alias] for alias in tilert_aliases]
+ converter = RMSNormProjxWqakisWeightsConverter(self.model_args, self.num_devices)
+ result = converter.convert_to_decoupled(weights_list)
+ self.tilert_wqakis, self.tilert_wis, self.tilert_norm_gamma = result
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
+ self.ki_out = torch.zeros((batch_size, seq_len, self.idx_head_dim), dtype=torch.bfloat16)
+ self.idx_scores_out = torch.zeros(
+ (batch_size, seq_len, self.idx_score_dim), dtype=torch.bfloat16
+ )
+ self.x_rmsnorm_out = torch.zeros((batch_size, seq_len, self.dim), dtype=torch.bfloat16)
+ self.x_rmsnorm_quant_out = torch.zeros(
+ (batch_size, seq_len, self.dim), dtype=torch.float8_e4m3fn
+ )
+ self.x_rmsnorm_quant_scale_out = torch.zeros(
+ (batch_size, seq_len, self.dim // self.block_size), dtype=torch.float32
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def init_random_weights(self) -> None:
+ bs = self.block_size
+ dim_scale_dim = self.dim // bs
+ q_scale_dim = (self.q_lora_rank + bs - 1) // bs
+ ki_scale_dim = (self.idx_head_dim + bs - 1) // bs
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+
+ tensor_list = [
+ torch.randn(self.dim, dtype=torch.float32),
+ torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(self.idx_head_dim, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(ki_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(self.idx_score_dim, self.dim, dtype=torch.bfloat16),
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def golden_forward(
+ self,
+ x: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Pure PyTorch reference: RMSNorm -> q, ki, idx_scores."""
+ assert self.ref_norm_gamma is not None
+ assert self.ref_wq_a is not None
+ assert self.ref_wki is not None
+ assert self.ref_wis is not None
+
+ x_rmsnorm = torch.nn.functional.rms_norm(
+ x.float(), [x.size(-1)], self.ref_norm_gamma, self.eps
+ )
+ q_out = torch.matmul(x_rmsnorm.float(), self.ref_wq_a.transpose(0, 1).float())
+ ki_out = torch.matmul(x_rmsnorm.float(), self.ref_wki.transpose(0, 1).float())
+ idx_scores_out = torch.matmul(x_rmsnorm.float(), self.ref_wis.transpose(0, 1).float())
+ return (
+ q_out.to(torch.bfloat16),
+ ki_out.to(torch.bfloat16),
+ idx_scores_out.to(torch.bfloat16),
+ )
+
+ def tilert_forward(
+ self,
+ x: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run decoupled RMSNorm + ProjXWqaki + ProjXWis via TileRT CUDA kernels."""
+ rmsnorm_quant(
+ x.to(torch.bfloat16),
+ self.tilert_norm_gamma,
+ self.x_rmsnorm_out,
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ projx_wqaki(
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.tilert_wqakis,
+ self.q_out,
+ self.ki_out,
+ self.profile_logs,
+ self.compute_kernel_type,
+ model_arch=self.model_args.arch_name,
+ )
+ wis_compute_kernel_type = "bf16"
+ projx_wis(
+ self.x_rmsnorm_out,
+ self.tilert_wis,
+ self.idx_scores_out,
+ wis_compute_kernel_type,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return self.q_out, self.ki_out, self.idx_scores_out
+
+ def __call__(
+ self,
+ x: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.golden_forward(x)
diff --git a/tilert/models/deepseek_v3_2/ops/rmsnorm_projx_wqkva.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_projx_wqkva.py
new file mode 100644
index 0000000..8e58a24
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_projx_wqkva.py
@@ -0,0 +1,516 @@
+"""RMSNormProjxWqkva operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RMSNormProjxWqkva",
+ "RMSNormProjxWqkvaAlgorithm",
+]
+
+
+class RMSNormProjQKVAFP8MMAWeightsConverter:
+ """Weight converter: pack FP8 weights for the FP8 MMA kernel."""
+
+ HIDDEN_DIM = 6144
+ Q_LORA_RANK = 2048
+ KV_LORA_RANK = 512
+ QK_ROPE_HEAD_DIM = 64
+ TOTAL_ROWS = Q_LORA_RANK + KV_LORA_RANK + QK_ROPE_HEAD_DIM
+ ROWS_PER_CTA = 32
+ NUM_CTAS = TOTAL_ROWS // ROWS_PER_CTA
+ COLS_PER_PAGE = 1024
+ NUM_PAGES = HIDDEN_DIM // COLS_PER_PAGE
+ SCALES_PER_PAGE = COLS_PER_PAGE // 128
+ BLOCK_SIZE = 128
+
+ MAT_BYTES = ROWS_PER_CTA * COLS_PER_PAGE
+ SCALE_OFFSET = MAT_BYTES
+ PAGE_BYTES = ((MAT_BYTES + 128 + 127) // 128) * 128
+
+ @staticmethod
+ def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a [*, 16, 32] tile for the MMA kernel."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def convert_to_fp8_mma_gemv(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wkv_a: torch.Tensor,
+ wkv_a_scale: torch.Tensor,
+ w_pe: torch.Tensor,
+ w_pe_scale: torch.Tensor,
+ attn_norm_weight: torch.Tensor,
+ *,
+ hidden_dim: int = 6144,
+ q_lora_rank: int = 2048,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Pack FP8 weights for the FP8 MMA kernel.
+
+ Args:
+ hidden_dim: Model hidden dimension.
+ q_lora_rank: Q projection rank.
+ """
+ C = RMSNormProjQKVAFP8MMAWeightsConverter
+ block_size = C.BLOCK_SIZE
+ kv_lora_rank = C.KV_LORA_RANK
+ qk_rope_head_dim = C.QK_ROPE_HEAD_DIM
+
+ expected = q_lora_rank * hidden_dim
+ assert wq_a.numel() == expected, f"wq_a numel {wq_a.numel()} != expected {expected}"
+ expected = kv_lora_rank * hidden_dim
+ assert wkv_a.numel() == expected, f"wkv_a numel {wkv_a.numel()} != expected {expected}"
+ expected = qk_rope_head_dim * hidden_dim
+ assert w_pe.numel() == expected, f"w_pe numel {w_pe.numel()} != expected {expected}"
+
+ total_rows = q_lora_rank + kv_lora_rank + qk_rope_head_dim
+ num_ctas = total_rows // C.ROWS_PER_CTA
+ num_pages = hidden_dim // C.COLS_PER_PAGE
+
+ wq_a_f = weight_dequant(wq_a.reshape(q_lora_rank, hidden_dim), wq_a_scale)
+ wkv_a_f = weight_dequant(wkv_a.reshape(kv_lora_rank, hidden_dim), wkv_a_scale)
+ w_pe_f = weight_dequant(w_pe.reshape(qk_rope_head_dim, hidden_dim), w_pe_scale)
+ w_float = torch.cat([wq_a_f, wkv_a_f, w_pe_f], dim=0)
+
+ w_blocks = w_float.reshape(total_rows, hidden_dim // block_size, block_size)
+ col_max = w_blocks.abs().amax(dim=(0, 2))
+ fp8_max = torch.finfo(torch.float8_e4m3fn).max
+ w_scales = (col_max / fp8_max).clamp(min=1e-12)
+
+ scales_expanded = w_scales.repeat_interleave(block_size)
+ w_scaled = w_float / scales_expanded.unsqueeze(0)
+ w_fp8 = w_scaled.to(torch.float8_e4m3fn)
+
+ assert C.MAT_BYTES == C.SCALE_OFFSET, "Layout mismatch: scales must follow mat"
+ assert block_size == C.COLS_PER_PAGE // C.SCALES_PER_PAGE, "Block size mismatch"
+ assert w_scales.numel() == num_pages * C.SCALES_PER_PAGE, "Scale count mismatch"
+
+ w_bytes = w_fp8.view(torch.uint8)
+ num_tiles = C.COLS_PER_PAGE // 32
+
+ mat = w_bytes.reshape(num_ctas, C.ROWS_PER_CTA, num_pages, C.COLS_PER_PAGE)
+ mat = mat.transpose(1, 2)
+
+ mat = mat.reshape(num_ctas, num_pages, 2, 16, num_tiles, 32)
+ mat = mat.transpose(3, 4)
+ mat = C._swizzle_mma_16x32(mat)
+ mat = mat.contiguous().reshape(num_ctas, num_pages, C.MAT_BYTES)
+
+ scales_f32 = w_scales.reshape(num_pages, C.SCALES_PER_PAGE).to(torch.float32).contiguous()
+ scales_bytes = scales_f32.view(torch.uint8)
+ scales_bytes = scales_bytes.unsqueeze(0).expand(num_ctas, -1, -1)
+
+ pad_size = C.PAGE_BYTES - C.MAT_BYTES - C.SCALES_PER_PAGE * 4
+ padding = torch.zeros(num_ctas, num_pages, pad_size, dtype=torch.uint8, device=w_fp8.device)
+
+ packed = torch.cat([mat, scales_bytes, padding], dim=-1)
+ packed = packed.contiguous().reshape(-1)
+
+ return packed.view(torch.float8_e4m3fn), attn_norm_weight.clone()
+
+
+class RMSNormProjQKVAFP16MMAWeightsConverter:
+ """Weight converter: pack FP16 weights for the FP16 MMA kernel."""
+
+ KV_LORA_RANK = 512
+ QK_ROPE_HEAD_DIM = 64
+ ROWS_PER_CTA = 32
+ COLS_PER_PAGE = 512
+ BLOCK_SIZE = 128
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a [*, 16, 16] tile for the MMA kernel."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def convert_to_fp16_mma_gemv(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wkv_a: torch.Tensor,
+ wkv_a_scale: torch.Tensor,
+ w_pe: torch.Tensor,
+ w_pe_scale: torch.Tensor,
+ attn_norm_weight: torch.Tensor,
+ *,
+ hidden_dim: int = 6144,
+ q_lora_rank: int = 2048,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Pack weights into the FP16 MMA layout."""
+ C = RMSNormProjQKVAFP16MMAWeightsConverter
+ kv_lora_rank = C.KV_LORA_RANK
+ qk_rope_head_dim = C.QK_ROPE_HEAD_DIM
+ cols_per_page = C.COLS_PER_PAGE
+ rows_per_cta = C.ROWS_PER_CTA
+
+ total_rows = q_lora_rank + kv_lora_rank + qk_rope_head_dim
+ num_ctas = total_rows // rows_per_cta
+ num_pages = hidden_dim // cols_per_page
+ num_k_tiles = cols_per_page // 16
+
+ wq_a_f = weight_dequant(wq_a.reshape(q_lora_rank, hidden_dim), wq_a_scale)
+ wkv_a_f = weight_dequant(wkv_a.reshape(kv_lora_rank, hidden_dim), wkv_a_scale)
+ w_pe_f = weight_dequant(w_pe.reshape(qk_rope_head_dim, hidden_dim), w_pe_scale)
+ w_float = torch.cat([wq_a_f, wkv_a_f, w_pe_f], dim=0)
+
+ w_fp16 = w_float.to(torch.float16)
+
+ mat = w_fp16.reshape(num_ctas, rows_per_cta, num_pages, cols_per_page)
+ mat = mat.transpose(1, 2)
+
+ mat = mat.reshape(num_ctas, num_pages, 2, 16, num_k_tiles, 16)
+ mat = mat.transpose(3, 4)
+ mat = C._swizzle_mma_16x16(mat)
+ mat = mat.contiguous()
+
+ mat_bytes = mat.view(torch.uint8).reshape(num_ctas, num_pages, -1)
+ packed = mat_bytes.contiguous().reshape(-1)
+
+ return packed.view(torch.float16), attn_norm_weight.clone()
+
+
+class RMSNormProjxWqkvaAlgorithm(Enum):
+ """RMSNormProjxWqkva algorithm."""
+
+ DECOUPLED = "decoupled"
+
+
+class RMSNormProjxWqkvaWeightsConverter(TilertWeightsConverter):
+ """Dispatch weight converter for RMSNormProjxWqkva."""
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args, num_devices)
+
+ def convert_to_fp8_mma_gemv(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert tilert weights list to FP8 MMA kernel-ready format.
+
+ Args:
+ weights: [gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale]
+ """
+ gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale = weights
+ return RMSNormProjQKVAFP8MMAWeightsConverter.convert_to_fp8_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ gamma,
+ hidden_dim=self.model_args.dim,
+ q_lora_rank=self.model_args.q_lora_rank,
+ )
+
+ def convert_to_fp16_mma_gemv(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert tilert weights list to FP16 MMA kernel-ready format.
+
+ Args:
+ weights: [gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale]
+ """
+ gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale = weights
+ return RMSNormProjQKVAFP16MMAWeightsConverter.convert_to_fp16_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ gamma,
+ hidden_dim=self.model_args.dim,
+ q_lora_rank=self.model_args.q_lora_rank,
+ )
+
+
+class RMSNormProjxWqkvaRefWeightsAlias:
+ """Reference weight aliases for RMSNormProjxWqkva."""
+
+ x_rmsnorm_gamma = "input_layernorm.weight"
+ q_a_weights = "self_attn.q_a_proj.weight"
+ q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
+ kv_a_weights = "self_attn.kv_a_proj_with_mqa.weight"
+ kv_a_scales = "self_attn.kv_a_proj_with_mqa.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+class RMSNormProjxWqkvaTilertWeightsAlias:
+ """Tilert weight aliases for RMSNormProjxWqkva."""
+
+ x_rmsnorm_gamma = "x_rmsnorm_gamma"
+ q_a_weights = "q_a_weights"
+ q_a_scales = "q_a_scales"
+ kv_a_weights = "kv_a_weights"
+ kv_a_scales = "kv_a_scales"
+ w_pe_weights = "w_pe_weights"
+ w_pe_scales = "w_pe_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ self.w_pe_weights,
+ self.w_pe_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RMSNormProjxWqkva(TileRTModule):
+ """Fused RMSNorm + GEMV(W_q_a, W_kv_a, W_pe)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormProjxWqkvaAlgorithm.DECOUPLED],
+ "glm_5": [RMSNormProjxWqkvaAlgorithm.DECOUPLED],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: RMSNormProjxWqkvaRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = RMSNormProjxWqkvaTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else RMSNormProjxWqkvaRefWeightsAlias()
+ )
+
+ self.dim = self.model_args.dim
+ self.q_lora_rank = self.model_args.q_lora_rank
+ self.kv_lora_rank = self.model_args.kv_lora_rank
+ self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ self.block_size = self.model_args.block_size
+ self.eps = self.model_args.eps
+ self.algorithm = RMSNormProjxWqkvaAlgorithm.DECOUPLED
+
+ self.ref_norm_gamma: torch.Tensor | None = None
+ self.ref_wq_a: torch.Tensor | None = None
+ self.ref_wkv_a: torch.Tensor | None = None
+ self.ref_w_pe: torch.Tensor | None = None
+
+ self.tilert_norm_gamma: torch.Tensor | None = None
+ self.tilert_wqkva: torch.Tensor | None = None
+ self.tilert_wqkva_scales = torch.zeros((1, 1), dtype=torch.bfloat16)
+
+ self.x_rmsnorm_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_scale_out: torch.Tensor | None = None
+
+ self.q_out: torch.Tensor | None = None
+ self.kv_out: torch.Tensor | None = None
+ self.pe_cache_out: torch.Tensor | None = None
+ self.cur_pos: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.tilert_tensor_alias: list[str] = [
+ "x_rmsnorm_gamma",
+ "qkva_weights",
+ "qkva_scales",
+ ]
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_norm_gamma, self.tilert_wqkva, self.tilert_wqkva_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat weights for device sharding."""
+ input_layernorm_weight = (
+ weights_map[self.ref_weights_alias.x_rmsnorm_gamma][None, ...]
+ .float()
+ .repeat(self.num_devices, 1)
+ )
+ q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ kv_a_mqa = weights_map[self.ref_weights_alias.kv_a_weights]
+ kv_a_proj_weight = kv_a_mqa[: self.kv_lora_rank, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight = kv_a_mqa[self.kv_lora_rank :, :][None, ...].repeat(self.num_devices, 1, 1)
+ kv_a_mqa_scale = weights_map[self.ref_weights_alias.kv_a_scales]
+ kv_scale_rows = (self.kv_lora_rank + self.block_size - 1) // self.block_size
+ kv_a_proj_weight_scale = kv_a_mqa_scale[:kv_scale_rows, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight_scale = kv_a_mqa_scale[kv_scale_rows:, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ return {
+ self.tilert_weights_alias.x_rmsnorm_gamma: input_layernorm_weight,
+ self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
+ self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
+ self.tilert_weights_alias.kv_a_weights: kv_a_proj_weight,
+ self.tilert_weights_alias.kv_a_scales: kv_a_proj_weight_scale,
+ self.tilert_weights_alias.w_pe_weights: w_pe_weight,
+ self.tilert_weights_alias.w_pe_scales: w_pe_weight_scale,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ aliases = self.ref_weights_alias()
+ self.ref_norm_gamma = state_dict[aliases[0]]
+ self.ref_wq_a = weight_dequant(state_dict[aliases[1]], state_dict[aliases[2]])
+ kv_a_mqa = weight_dequant(state_dict[aliases[3]], state_dict[aliases[4]])
+ self.ref_wkv_a = kv_a_mqa[: self.kv_lora_rank, :]
+ self.ref_w_pe = kv_a_mqa[self.kv_lora_rank :, :]
+
+ assert self.ref_norm_gamma.shape[-1] == self.dim
+ assert self.ref_wq_a.shape == (self.q_lora_rank, self.dim)
+ assert self.ref_wkv_a.shape == (self.kv_lora_rank, self.dim)
+ assert self.ref_w_pe.shape == (self.qk_rope_head_dim, self.dim)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ tilert_aliases = self.tilert_weights_alias()
+ weights_list = [state_dict[alias] for alias in tilert_aliases]
+ converter = RMSNormProjxWqkvaWeightsConverter(self.model_args, self.num_devices)
+ self.tilert_wqkva, self.tilert_norm_gamma = converter.convert_to_fp8_mma_gemv(weights_list)
+ self.tilert_wqkva_scales = torch.zeros((1,), dtype=torch.float32)
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, max_len: int = 128) -> None:
+ self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
+ self.kv_out = torch.zeros((batch_size, seq_len, self.kv_lora_rank), dtype=torch.bfloat16)
+ self.pe_cache_out = torch.zeros(
+ (batch_size, max_len, self.qk_rope_head_dim), dtype=torch.bfloat16
+ )
+ self.cur_pos = torch.zeros((1,), dtype=torch.int32)
+ self.x_rmsnorm_out = torch.zeros((batch_size, seq_len, self.dim), dtype=torch.bfloat16)
+ self.x_rmsnorm_quant_out = torch.zeros(
+ (batch_size, seq_len, self.dim), dtype=torch.float8_e4m3fn
+ )
+ self.x_rmsnorm_quant_scale_out = torch.zeros(
+ (batch_size, seq_len, self.dim // self.block_size), dtype=torch.float32
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def init_random_weights(self) -> None:
+ bs = self.block_size
+ dim_scale_dim = self.dim // bs
+ q_scale_dim = (self.q_lora_rank + bs - 1) // bs
+ kv_mqa_rows = self.kv_lora_rank + self.qk_rope_head_dim
+ kv_mqa_scale_dim = (kv_mqa_rows + bs - 1) // bs
+ scale_dtype = torch.bfloat16
+
+ tensor_list = [
+ torch.randn(self.dim, dtype=torch.float32),
+ torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(kv_mqa_rows, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(kv_mqa_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def golden_forward(
+ self,
+ x: torch.Tensor,
+ cur_pos: int = 0, # noqa: U100
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Pure PyTorch reference: RMSNorm -> q, kv, pe."""
+ assert self.ref_norm_gamma is not None
+ assert self.ref_wq_a is not None
+ assert self.ref_wkv_a is not None
+ assert self.ref_w_pe is not None
+
+ x_rmsnorm = torch.nn.functional.rms_norm(
+ x.float(), [x.size(-1)], self.ref_norm_gamma, self.eps
+ )
+ q_out = torch.matmul(x_rmsnorm.float(), self.ref_wq_a.transpose(0, 1).float())
+ kv_out = torch.matmul(x_rmsnorm.float(), self.ref_wkv_a.transpose(0, 1).float())
+ pe_out = torch.matmul(x_rmsnorm.float(), self.ref_w_pe.transpose(0, 1).float())
+ return (
+ q_out.to(torch.bfloat16),
+ kv_out.to(torch.bfloat16),
+ pe_out.to(torch.bfloat16),
+ )
+
+ def tilert_forward(
+ self,
+ x: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run RMSNorm + 3-way GEMV via TileRT CUDA kernel (DECOUPLED)."""
+ assert self.cur_pos is not None
+ assert self.pe_cache_out is not None
+ self.cur_pos.fill_(cur_pos)
+
+ from tilert.models.deepseek_v3_2.ops.projx_wqkva import projx_wqkva as _projx_wqkva
+ from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant as _rmsnorm_quant
+
+ _rmsnorm_quant(
+ x.to(torch.bfloat16),
+ self.tilert_norm_gamma,
+ self.x_rmsnorm_out,
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ _projx_wqkva(
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.tilert_wqkva,
+ self.cur_pos,
+ self.q_out,
+ self.kv_out,
+ self.pe_cache_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+
+ seq_len = x.size(-2)
+ pe_at_pos = self.pe_cache_out[:, cur_pos : cur_pos + seq_len, :]
+ return self.q_out, self.kv_out, pe_at_pos
+
+ def __call__(
+ self,
+ x: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.golden_forward(x, cur_pos)
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_quant.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_quant.py
similarity index 56%
rename from python/models/deepseek_v3_2/ops/rmsnorm_quant.py
rename to tilert/models/deepseek_v3_2/ops/rmsnorm_quant.py
index 770db02..1d399c5 100644
--- a/python/models/deepseek_v3_2/ops/rmsnorm_quant.py
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_quant.py
@@ -1,8 +1,4 @@
-"""RMSNormQuant operation module.
-
-Unified for deepseek_v3_2 (dim=7168) and glm_5 (dim=6144).
-Dispatches by hidden_in.shape[-1]: 7168 -> rmsnorm_*_op, 6144 -> rmsnorm_*_glm5_op.
-"""
+"""RMSNormQuant operation module."""
from __future__ import annotations
@@ -27,14 +23,13 @@ def rmsnorm_quant(
quant_hidden_out: torch.Tensor | None = None,
quant_hidden_scale_out: torch.Tensor | None = None,
profile_logs: torch.Tensor | None = None,
+ compute_kernel_type: str = "general",
+ *,
+ model_arch: str,
) -> None:
"""
Rmsnorm with optional activation quantization.
- Unified for deepseek_v3_2 (dim=7168) and glm_5 (dim=6144). Dispatches by
- hidden_in.shape[-1]: 7168 -> rmsnorm_op / rmsnorm_quant_op,
- 6144 -> rmsnorm_glm5_op / rmsnorm_quant_glm5_op.
-
Args:
hidden_in: Input tensor (..., dim).
gamma_in: RMSNorm gamma (dim,).
@@ -43,31 +38,27 @@ def rmsnorm_quant(
quant_hidden_scale_out: Optional quant scale (..., dim // block_size). If None, no quant.
profile_logs: Optional profile logs tensor.
"""
- dim = hidden_in.shape[-1]
- if dim == DIM_GLM_5:
- glm5_flag = "_glm5"
- elif dim == DIM_DEEPSEEK_V3_2:
- glm5_flag = ""
- else:
- raise ValueError(
- f"Unsupported hidden_in.shape[-1]: {dim}. "
- f"rmsnorm_quant supports {DIM_DEEPSEEK_V3_2} (deepseek_v3_2) or {DIM_GLM_5} (glm_5)."
- )
+ if profile_logs is None:
+ raise ValueError("profile_logs is required when calling rmsnorm_quant.")
+
if quant_hidden_out is None or quant_hidden_scale_out is None:
- quant_flag = ""
- quant_args = [hidden_in, gamma_in, hidden_out, profile_logs]
+ torch.ops.tilert.rmsnorm_op(
+ hidden_in,
+ gamma_in,
+ hidden_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
else:
- quant_flag = "_quant"
- quant_args = [
+ torch.ops.tilert.rmsnorm_quant_op(
hidden_in,
gamma_in,
hidden_out,
quant_hidden_out,
quant_hidden_scale_out,
+ model_arch,
+ compute_kernel_type,
profile_logs,
- ]
- if profile_logs is None:
- raise ValueError("profile_logs is required when calling rmsnorm_quant.")
- func_name = f"rmsnorm{quant_flag}{glm5_flag}_op"
- func_call = getattr(torch.ops.tilert, func_name)
- func_call(*quant_args)
+ torch.empty(0, dtype=torch.int64, device=hidden_in.device),
+ )
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py b/tilert/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py
similarity index 93%
rename from python/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py
rename to tilert/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py
index e2f5c59..db991da 100644
--- a/python/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py
+++ b/tilert/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py
@@ -13,7 +13,6 @@
ExpertSelectUpGateSiLU,
ExpertSelectUpGateSiLUWeightsConverter,
)
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -30,6 +29,7 @@ def rmsnorm_up_gate_silu(
weights_in: torch.Tensor,
hidden_out: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
compute_kernel_type: str = "fp8mma",
) -> None:
"""rmsnorm_up_gate_silu operation."""
@@ -38,8 +38,9 @@ def rmsnorm_up_gate_silu(
gamma_in,
weights_in,
hidden_out,
- profile_logs,
+ model_arch,
compute_kernel_type,
+ profile_logs,
)
@@ -48,6 +49,7 @@ class RMSNormUpGateSiLUAlgorithm(Enum):
FP8MMA = "fp8mma"
FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
RMSNormUpGateSiLUWeightsConverter = ExpertSelectUpGateSiLUWeightsConverter
@@ -81,6 +83,15 @@ def __call__(self) -> list[str]:
class RMSNormUpGateSiLU(TileRTModule):
"""RMSNormUpGateSiLU module"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ RMSNormUpGateSiLUAlgorithm.FP8MMA,
+ RMSNormUpGateSiLUAlgorithm.FP16MMA,
+ RMSNormUpGateSiLUAlgorithm.BF16MMA,
+ ],
+ "glm_5": [RMSNormUpGateSiLUAlgorithm.FP8MMA, RMSNormUpGateSiLUAlgorithm.FP16MMA],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -102,38 +113,31 @@ def __init__(
self.moe_inter_dim = self.model_args.moe_inter_dim
self.moe_inter_dim_per_device = self.moe_inter_dim // self.num_devices
self.inter_dim_per_device = self.inter_dim // self.num_devices
- # effective number of experts
self.n_experts = self.inter_dim_per_device // self.moe_inter_dim_per_device
self.eps = self.model_args.eps
self.block_size = self.model_args.block_size
self.algorithm = algorithm
- # reference weights
self.ref_norm_gamma: torch.Tensor | None = None
self.ref_gate: torch.Tensor | None = None
self.ref_up: torch.Tensor | None = None
- # tilert weights
self.tilert_norm_gamma: torch.Tensor | None = None
self.tilert_weights: torch.Tensor | None = None
- # for compatibility, to be removed in the future
self.tilert_scales = torch.zeros(
9, 4, 64, dtype=torch.bfloat16, device=torch.device("cuda")
)
- # tilert vars
self.hidden_out: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
self.is_init = False
- # tilert_funcs
self.rmsnorm_up_gate_silu_func = rmsnorm_up_gate_silu
self.tilert_weights_alias = RMSNormUpGateSiLUTilertWeightsAlias()
- # reference tensor aliases
self.ref_tensor_alias: list[str] = [
"post_attention_layernorm.weight",
"mlp.gate_proj.weight",
@@ -158,7 +162,7 @@ def get_weights_list(self) -> list[torch.Tensor]:
def device_sharding(
self,
weights_dict: dict[str, torch.Tensor],
- key_prefix: str, # e.g. model.layers.{layer_id}.mlp
+ key_prefix: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Device sharding.
@@ -176,7 +180,6 @@ def device_sharding(
elif key_prefix == "mlp":
rmsnorm_gamma_key = "post_attention_layernorm.weight"
rmsnorm_gamma = weights_dict[rmsnorm_gamma_key]
- # repeat rmsnorm_gamma for each device
rmsnorm_gamma = rmsnorm_gamma[None, :].repeat(self.num_devices, 1)
gate_weights, gate_scales, up_weights, up_scales = (
@@ -186,7 +189,6 @@ def device_sharding(
self.num_devices,
)
)
- # Transpose split so to match the old convertcode
gate_weights = gate_weights.reshape(self.n_experts, self.num_devices, -1, self.dim)
gate_weights = gate_weights.transpose(0, 1)
gate_scales = gate_scales.reshape(
@@ -210,7 +212,7 @@ def device_sharding(
def init_reference_weights(
self,
state_dict: dict[str, torch.Tensor],
- key_prefix: str, # e.g. model.layers.{layer_id}.mlp
+ key_prefix: str,
device_id: int = 0,
) -> None:
"""
@@ -259,7 +261,6 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, dev_id: int = 0) -> No
batch_size: Batch size.
seq_len: Sequence length.
"""
- # tilert vars
self.hidden_out = torch.zeros(
(
batch_size,
@@ -274,13 +275,15 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, dev_id: int = 0) -> No
self.profile_logs = get_profile_log_tensor(device=f"cuda:{dev_id}")
self.is_init = True
- def init_random_weights(self, dev_id: int = 0) -> None:
+ def init_random_weights(self, dev_id: int | None = None) -> None:
"""
Initialize the random weights.
Returns:
None
"""
+ if dev_id is None:
+ dev_id = self.device_id
gamma = torch.randn(self.dim, dtype=torch.float32, device=f"cuda:{dev_id}")
gate_weights = torch.randn(
self.inter_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{dev_id}"
@@ -326,7 +329,6 @@ def golden_forward(
)
hidden_out_list = []
for s in range(seq_len):
- # ref up-gate silu
hidden_out_w1_list = []
hidden_out_w3_list = []
@@ -356,12 +358,9 @@ def tilert_forward(
self.tilert_weights,
self.hidden_out,
self.profile_logs,
- self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ compute_kernel_type=self.algorithm.value,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return self.hidden_out
def __call__(
diff --git a/tilert/models/deepseek_v3_2/ops/rotate.py b/tilert/models/deepseek_v3_2/ops/rotate.py
new file mode 100644
index 0000000..19c2746
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/rotate.py
@@ -0,0 +1,226 @@
+"""Rotate(hadamard transform) operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+import torch.nn.functional as F
+
+from tilert.models.base import TileRTModule
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.utils import apply_rotary_emb
+from tilert.utils import get_profile_log_tensor
+
+try:
+ from fast_hadamard_transform import hadamard_transform
+
+ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
+ assert x.dtype == torch.bfloat16
+ hidden_size = x.size(-1)
+ return hadamard_transform(x, scale=hidden_size**-0.5)
+
+except ImportError:
+ print(
+ "Cannot import hadamard_transform, fallback to scipy.linalg.hadamard."
+ "please install fast_hadamard_transform for correct performance."
+ )
+ import math
+
+ from scipy.linalg import hadamard
+
+ def hadamard_transform_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ x_shape = x.shape
+ dim = x.shape[-1]
+ x = x.reshape(-1, dim)
+ log_dim = math.ceil(math.log2(dim))
+ dim_padded = 2**log_dim
+ if dim != dim_padded:
+ x = F.pad(x, (0, dim_padded - dim))
+ out = F.linear(
+ x,
+ torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
+ )
+ out = out * scale
+ return out[..., :dim].reshape(*x_shape)
+
+ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
+ assert x.dtype == torch.bfloat16
+ hidden_size = x.size(-1)
+ return hadamard_transform_ref(x, scale=hidden_size**-0.5)
+
+
+__all__ = [
+ "rotate",
+ "rotate_activation",
+ "Rotate",
+ "RotateRefWeightsAlias",
+ "RotateTilertWeightsAlias",
+]
+
+
+def rotate(
+ input_raw: torch.Tensor,
+ output_raw: torch.Tensor,
+ freqs_cis_raw: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+ kv_cache: torch.Tensor | None = None,
+ cur_pos: torch.Tensor | None = None,
+ cache_base: int = 0,
+ cache_stride: int = 0,
+ cache_compressed: bool = False,
+) -> None:
+ """
+ Rotate (hadamard transform) operation.
+
+ Args:
+ input_raw (torch.Tensor): The input tensor [..., head, 128].
+ output_raw (torch.Tensor): The output tensor where the result will be stored.
+ freqs_cis_raw (torch.Tensor): The frequency tensor.
+ profile_logs (torch.Tensor): Tensor for storing profiling logs.
+ model_arch: Model architecture string.
+ compute_kernel_type: Compute kernel type string.
+ kv_cache: Optional cache write target.
+ cur_pos: Optional [1] int32 tensor. Required when kv_cache is set.
+ cache_base: Base row index.
+ cache_stride: Stride. Must be > 0 when ``cache_compressed=True``.
+ cache_compressed: Cache write mode selector.
+
+ Returns:
+ None
+ """
+ torch.ops.tilert.rotate_op(
+ input_raw,
+ output_raw,
+ freqs_cis_raw,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ kv_cache,
+ cur_pos,
+ cache_base,
+ cache_stride,
+ cache_compressed,
+ )
+
+
+@dataclass
+class RotateRefWeightsAlias:
+ """Reference weights alias for Rotate (no weights)."""
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return []
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class RotateTilertWeightsAlias:
+ """TileRT weights alias for Rotate (no weights)."""
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return []
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RotateAlgorithm(Enum):
+ """Rotate algorithm."""
+
+ GENERAL = "general"
+
+
+class Rotate(TileRTModule):
+ """Rotate module: RoPE on first qk_rope_head_dim dims + hadamard transform.
+
+ Unified for deepseek_v3_2 (index_n_heads=64) and glm_5 (index_n_heads=32).
+ No weights; uses model_args for dimensions.
+ """
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RotateAlgorithm.GENERAL],
+ "glm_5": [RotateAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int = 1,
+ device_id: int = 0,
+ ref_weights_alias: RotateRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+ self.tilert_weights_alias = RotateTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else RotateRefWeightsAlias()
+ )
+
+ self.qk_rope_head_dim = model_args.qk_rope_head_dim
+ self.index_n_heads = model_args.index_n_heads
+ self.index_head_dim = model_args.index_head_dim
+
+ self.output: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return []
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ del weights_map
+ return {}
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ del state_dict
+ pass
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ del state_dict
+ pass
+
+ def init_random_weights(self) -> None:
+ pass
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.output = torch.zeros(
+ (batch_size, seq_len, self.index_n_heads, self.index_head_dim),
+ dtype=torch.bfloat16,
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def golden_forward(
+ self,
+ idx_q: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ) -> torch.Tensor:
+ q_pe_idx, q_nope_idx = torch.split(
+ idx_q,
+ [self.qk_rope_head_dim, self.index_head_dim - self.qk_rope_head_dim],
+ dim=-1,
+ )
+ q_pe_idx = apply_rotary_emb(q_pe_idx, freqs_cis, interleaved=False)
+ idx_q = torch.cat([q_pe_idx, q_nope_idx], dim=-1)
+ return rotate_activation(idx_q)
+
+ def tilert_forward(self, idx_q: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ assert self.output is not None
+ assert self.profile_logs is not None
+ freqs_cis_real = torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1)
+ rotate(
+ idx_q,
+ self.output,
+ freqs_cis_real,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.output
diff --git a/python/models/deepseek_v3_2/ops/sparse_index.py b/tilert/models/deepseek_v3_2/ops/sparse_index.py
similarity index 87%
rename from python/models/deepseek_v3_2/ops/sparse_index.py
rename to tilert/models/deepseek_v3_2/ops/sparse_index.py
index 0c21ce8..870855b 100644
--- a/python/models/deepseek_v3_2/ops/sparse_index.py
+++ b/tilert/models/deepseek_v3_2/ops/sparse_index.py
@@ -15,6 +15,9 @@ def sparse_index(
logits: torch.Tensor,
cur_pos: int,
profile_logs: torch.Tensor,
+ compute_kernel_type: str = "bf16",
+ *,
+ model_arch: str,
) -> None:
"""
Sparse index operation.
@@ -28,6 +31,8 @@ def sparse_index(
logits (torch.Tensor): The logits tensor.
cur_pos (int): The position of the first token.
profile_logs (torch.Tensor): Tensor for storing profiling logs.
+ compute_kernel_type (str): Kernel type ("bf16").
+ model_arch (str): Model architecture ("deepseek_v3_2").
Returns:
None
@@ -59,10 +64,9 @@ def sparse_index(
f"q={device}, kv={kv.device}, weights={weights.device}, "
f"logits={logits.device}, profile_logs={profile_logs.device}"
)
- if head == 64:
- torch.ops.tilert.sparse_index_op(q, kv, weights, logits, cur_pos, profile_logs)
- elif head == 32:
- torch.ops.tilert.sparse_index_glm5_op(q, kv, weights, logits, cur_pos, profile_logs)
+ torch.ops.tilert.sparse_index_op(
+ q, kv, weights, logits, cur_pos, model_arch, compute_kernel_type, profile_logs
+ )
def sparse_index_topk(
@@ -103,10 +107,10 @@ def sparse_index_topk(
head = q.shape[-2]
dim = q.shape[-1]
- if head != 32:
+ if head != 64:
raise ValueError(
- f"Unsupported head size: {head}. Sparse index topk fused op currently only \
- supports a head number of 32."
+ f"Unsupported head size: {head}. Sparse index topk fused op "
+ "supports head number of 64 (DSV3.2)."
)
if dim != 128:
raise ValueError("dim must be 128, as we precompute scale inner kernel")
@@ -118,7 +122,7 @@ def sparse_index_topk(
f"q={device}, kv={kv.device}, weights={weights.device}, "
f"logits={logits.device}, profile_logs={profile_logs.device}"
)
- workspace = torch.zeros(seqlen, (200 * 1024 + 258), dtype=torch.int32, device=device)
- torch.ops.tilert.sparse_index_topk_glm5_op(
+ workspace = torch.zeros(seqlen, (200 * 1024 + 260), dtype=torch.int32, device=device)
+ torch.ops.tilert.sparse_index_topk_dsv32_op(
q, kv, weights, logits, cur_pos, indices, workspace, profile_logs
)
diff --git a/tilert/models/deepseek_v3_2/ops/topk.py b/tilert/models/deepseek_v3_2/ops/topk.py
new file mode 100644
index 0000000..49c58eb
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/topk.py
@@ -0,0 +1,171 @@
+"""topk operations module."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from tilert.utils import get_profile_log_tensor
+
+if TYPE_CHECKING:
+ from tilert.models.deepseek_v3_2.model_args import ModelArgs
+
+
+__all__ = [
+ "TopK",
+ "topk_approximate",
+ "topk_accurate",
+]
+
+
+def topk_approximate(
+ logits: torch.Tensor,
+ seq_len: int,
+ topk: int,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+) -> torch.Tensor:
+ """
+ Topk approximate operation.
+
+ Topk approximate the input tensor `logits` and stores the result in `output_raw`.
+
+ Args:
+ logits (torch.Tensor): The input tensor.
+ seq_len (int): valid data of logits.shape[-1]
+ topk (int): The number of topk to approximate.
+ profile_logs (torch.Tensor): The profile logs tensor.
+
+ Returns:
+ indices (torch.Tensor): The output tensor.
+ """
+ if logits.dtype != torch.float32:
+ raise ValueError("logits must be a float32 tensor.")
+
+ if topk != 2048:
+ raise ValueError("topk must be 2048.")
+ batch = logits.shape[0]
+ if batch != 1:
+ raise ValueError("batch must be 1 in this version")
+
+ indices = torch.zeros(batch, topk, dtype=torch.int32, device=logits.device)
+ torch.ops.tilert.topk_approximate_op(
+ logits, indices, seq_len, model_arch, compute_kernel_type, profile_logs
+ )
+
+ return indices
+
+
+def topk_accurate(
+ logits: torch.Tensor,
+ seq_len: int,
+ topk: int,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+ ratio: int = 1,
+) -> torch.Tensor:
+ """
+ Topk approximate operation.
+
+ Topk approximate the input tensor `logits` and stores the result in `output_raw`.
+
+ Args:
+ logits (torch.Tensor): The input tensor.
+ seq_len (int): length of last samples,
+ for k=logits.shape[1] samples, the length is
+ seq-k+1, seq-k+2, ..., seq-1, seq
+ topk (int): The number of topk to approximate.
+ profile_logs (torch.Tensor): The profile logs tensor.
+ ratio (int): Token-domain to logits-trailing-dim compression factor.
+ Returns:
+ indices (torch.Tensor): The output tensor.
+ """
+ if logits.dtype != torch.float32:
+ raise ValueError("logits must be a float32 tensor.")
+
+ if topk not in (512, 1024, 2048):
+ raise ValueError("topk must be 512, 1024, or 2048.")
+
+ assert logits.shape[0] == 1, "batch must be 1 in this version"
+ num_samples = logits.shape[1]
+
+ indices = torch.zeros(num_samples, topk, dtype=torch.int32, device=logits.device)
+ indices_ws = torch.zeros(1, num_samples, 4, topk * 2, dtype=torch.int32, device=logits.device)
+ torch.ops.tilert.topk_accurate_op(
+ logits,
+ indices,
+ seq_len - num_samples,
+ indices_ws,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ ratio,
+ )
+
+ return indices
+
+
+class TopK(nn.Module):
+ """TopK operation with optional approximate kernel.
+
+ Wraps topk_accurate / topk_approximate and provides golden_forward
+ (reference implementation) and tilert_forward (TileRT kernel).
+ """
+
+ def __init__(self, use_approximate: bool = False, model_args: ModelArgs | None = None) -> None:
+ super().__init__()
+ self.use_approximate = use_approximate
+ if model_args is None:
+ from tilert.models.deepseek_v3_2.model_args import ModelArgs
+
+ model_args = ModelArgs()
+ self.model_args = model_args
+
+ def golden_forward(
+ self,
+ logits: torch.Tensor,
+ topk: int,
+ ) -> torch.Tensor:
+ """Reference forward: torch.topk on the last dimension.
+
+ Args:
+ logits: Scores tensor, shape (batch, ..., seq_len).
+ topk: Number of top indices to return.
+
+ Returns:
+ Indices of top-k values along the last dimension.
+ """
+ seq_len = logits.shape[-1]
+ return logits.topk(min(topk, seq_len), dim=-1)[1]
+
+ def tilert_forward(
+ self,
+ logits: torch.Tensor,
+ topk: int,
+ ) -> torch.Tensor:
+ """Tilert forward: batch of samples with varying valid length.
+
+ Args:
+ logits: Shape (batch, num_samples, cache_len).
+ topk: Number of top indices to return.
+
+ Returns:
+ Indices tensor of shape (batch, num_samples, topk).
+ """
+ profile_logs = get_profile_log_tensor(device=logits.device)
+ cache_len = logits.shape[-1]
+ if self.use_approximate:
+ indices = topk_approximate(
+ logits, cache_len, topk, profile_logs, model_arch=self.model_args.arch_name
+ )
+ else:
+ indices = topk_accurate(
+ logits, cache_len, topk, profile_logs, model_arch=self.model_args.arch_name
+ )
+ if indices.dim() == 2:
+ return indices.unsqueeze(0)
+ return indices
diff --git a/tilert/models/deepseek_v3_2/ops/unproj_o_allreduce.py b/tilert/models/deepseek_v3_2/ops/unproj_o_allreduce.py
new file mode 100644
index 0000000..33520b5
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/ops/unproj_o_allreduce.py
@@ -0,0 +1,526 @@
+"""UnprojOAllreduce operation module."""
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "unproj_o_allreduce",
+ "UnProjOAllReduce",
+ "UnProjOAllReduceAlgorithm",
+ "UnProjOAllReduceRefWeightsAlias",
+ "UnProjOAllReduceTilertWeightsAlias",
+]
+
+
+def unproj_o_allreduce(
+ vec_in: torch.Tensor,
+ mat_in: torch.Tensor,
+ mat_scale: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ vec_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """
+ Fused operation of unprojection and allreduce.
+
+ Args:
+ vec_in: Input tensor.
+ mat_in: Input tensor.
+ mat_scale: Input tensor.
+ x_in: Input tensor.
+ flag: Input flag.
+ vec_out: Output tensor.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16", "fp16mma").
+ """
+ torch.ops.tilert.unproj_o_allreduce_op(
+ vec_in,
+ mat_in,
+ mat_scale,
+ x_in,
+ flag,
+ vec_out,
+ profile_logs,
+ model_arch,
+ compute_kernel_type,
+ )
+
+
+class UnProjOAllReduceAlgorithm(Enum):
+ """UnprojOAllReduce algorithm"""
+
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+@dataclass
+class UnProjOAllReduceRefWeightsAlias:
+ """Reference weights alias for UnProjOAllReduce."""
+
+ o_proj_weight = "self_attn.o_proj.weight"
+ o_proj_scale_inv = "self_attn.o_proj.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.o_proj_weight, self.o_proj_scale_inv]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class UnProjOAllReduceTilertWeightsAlias:
+ """TileRT weights alias for UnProjOAllReduce."""
+
+ unproj_weights = "unproj_weights"
+ unproj_scales = "unproj_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.unproj_weights, self.unproj_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class UnProjOAllReduceWeightsConverter(TilertWeightsConverter):
+ """UnProjOAllReduce weights converter"""
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ def convert_to_fp16mma_128cta(
+ self,
+ weights_list: list[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert weights to the FP16 MMA layout for the 128-CTA config."""
+ with torch.inference_mode():
+ mat, scales = weights_list
+ if scales.dtype != torch.float32:
+ scales = scales.to(torch.float32)
+
+ dim = self.model_args.dim
+ block_size = self.model_args.block_size
+ sms = 128
+ vec_dim = mat.shape[-1]
+ dim_per_sm = dim // sms
+ full_tiles = dim_per_sm // 16
+ remainder_rows = dim_per_sm % 16
+ stages = vec_dim // 512
+ vec_scale_dim = vec_dim // block_size
+ scale_per_stage = vec_scale_dim // stages
+
+ dim_scale_dim = dim // block_size
+ scales_per_full_tile = 2 if remainder_rows > 0 else 1
+ rem_scales = 1 if remainder_rows > 0 else 0
+ total_scale_slots = (full_tiles * scales_per_full_tile + rem_scales) * scale_per_stage
+ repeat_factor = 8 if remainder_rows == 0 else 16
+
+ sc = scales.reshape(dim_scale_dim, 1, vec_scale_dim)
+ sc = sc.repeat(1, repeat_factor, 1)
+ scales_per_cta = full_tiles * scales_per_full_tile + rem_scales
+ sc = (
+ sc.reshape(sms, scales_per_cta, stages, scale_per_stage)
+ .transpose(1, 2)
+ .reshape(sms, stages, total_scale_slots)
+ .view(torch.float8_e4m3fn)
+ )
+ sc_packed = sc
+
+ mat_per_sm = mat.reshape(sms, dim_per_sm, vec_dim)
+
+ full_rows = full_tiles * 16
+ mat_full = (
+ mat_per_sm[:, :full_rows, :]
+ .reshape(sms, full_tiles, 16, stages, 512)
+ .transpose(2, 3)
+ .reshape(sms, full_tiles, stages, 16, 32, 16)
+ .transpose(3, 4)
+ .reshape(sms, full_tiles, stages, 32, 16, 16)
+ )
+ mat_full = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat_full)
+ mat_full = mat_full.transpose(1, 2).reshape(sms, stages, -1)
+
+ if remainder_rows > 0:
+ mat_rem_raw = mat_per_sm[:, full_rows:, :]
+ mat_rem_padded = torch.zeros(
+ sms, 16, vec_dim, dtype=mat_rem_raw.dtype, device=mat_rem_raw.device
+ )
+ mat_rem_padded[:, :remainder_rows, :] = mat_rem_raw
+ mat_rem = (
+ mat_rem_padded.reshape(sms, 1, 16, stages, 512)
+ .transpose(2, 3)
+ .reshape(sms, 1, stages, 16, 32, 16)
+ .transpose(3, 4)
+ .reshape(sms, 1, stages, 32, 16, 16)
+ )
+ mat_rem = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat_rem)
+ mat_rem = mat_rem.transpose(1, 2).reshape(sms, stages, -1)
+ mat_combined = torch.cat([mat_full, mat_rem], dim=-1)
+ else:
+ mat_combined = mat_full
+
+ scales_padding = torch.zeros(
+ sms,
+ stages,
+ 128 - sc_packed.shape[-1],
+ dtype=torch.float8_e4m3fn,
+ device=mat.device,
+ )
+ mat_all = torch.cat([mat_combined, sc_packed, scales_padding], dim=-1).contiguous()
+ dummy_scales = torch.zeros(1, dtype=torch.float32, device=mat.device)
+ return mat_all, dummy_scales
+
+ def convert_to_bf16mma(
+ self,
+ weights_list: list[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert common weights to the BF16 MMA layout."""
+ assert (
+ self.model_args.arch_name == "deepseek_v3_2"
+ ), "BF16 MMA dispatch is wired only for DeepSeek-V3.2 DevGroupB."
+ return self.convert_to_fp16mma_128cta(weights_list)
+
+ def convert_to_fp16mma(
+ self,
+ weights_list: list[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert common weights to TileRT FP16 MMA layout."""
+ if self.model_args.arch_name == "deepseek_v3_2":
+ return self.convert_to_fp16mma_128cta(weights_list)
+ assert self.model_args.arch_name == "glm_5", "Only GLM-5 and DSV3.2 support FP16 MMA"
+
+ with torch.inference_mode():
+ mat, scales = weights_list
+ if scales.dtype != torch.float32:
+ print(
+ "Warning: UnProjOAllReduceWeightsConverter: "
+ + f"scales.dtype: {scales.dtype} "
+ + "is not float32, convert to float32."
+ )
+ scales = scales.to(torch.float32)
+
+ dim = self.model_args.dim
+ block_size = self.model_args.block_size
+ sms = 128
+ vec_dim = mat.shape[-1]
+ dim_per_sm = dim // sms
+ tiles_per_stage = dim_per_sm // 16
+ stages = vec_dim // 512
+ dim_scale_dim = dim // block_size
+ vec_scale_dim = vec_dim // block_size
+ scale_per_stage = vec_scale_dim // stages
+
+ scales = scales.reshape(dim_scale_dim, 1, vec_scale_dim)
+ scales = scales.repeat(1, 8, 1)
+ scales = (
+ scales.reshape(sms, tiles_per_stage, stages, scale_per_stage)
+ .transpose(1, 2)
+ .reshape(sms, stages, tiles_per_stage * scale_per_stage)
+ .view(torch.float8_e4m3fn)
+ )
+
+ mat = (
+ mat.reshape(sms, dim_per_sm, vec_dim)
+ .reshape(sms, tiles_per_stage, 16, stages, 512)
+ .transpose(2, 3)
+ .reshape(sms, tiles_per_stage, stages, 16, 32, 16)
+ .transpose(3, 4)
+ .reshape(sms, tiles_per_stage, stages, 32, 16, 16)
+ )
+ mat = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat)
+ mat = mat.transpose(1, 2).reshape(sms, stages, -1)
+
+ scales_padding = torch.zeros(
+ sms,
+ stages,
+ 128 - scales.shape[-1],
+ dtype=torch.float8_e4m3fn,
+ device=mat.device,
+ )
+ mat_full = torch.cat([mat, scales, scales_padding], dim=-1).contiguous()
+ dummy_scales = torch.zeros(1, dtype=torch.float32, device=mat.device)
+ return mat_full, dummy_scales
+
+
+class UnProjOAllReduce(TileRTModule):
+ """UnProjOAllReduce module"""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ UnProjOAllReduceAlgorithm.FP16MMA,
+ UnProjOAllReduceAlgorithm.BF16MMA,
+ ],
+ "glm_5": [
+ UnProjOAllReduceAlgorithm.FP16MMA,
+ ],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: UnProjOAllReduceRefWeightsAlias | None = None,
+ tilert_weights_alias: UnProjOAllReduceTilertWeightsAlias | None = None,
+ algorithm: UnProjOAllReduceAlgorithm = UnProjOAllReduceAlgorithm.FP16MMA,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = (
+ tilert_weights_alias
+ if tilert_weights_alias is not None
+ else UnProjOAllReduceTilertWeightsAlias()
+ )
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else UnProjOAllReduceRefWeightsAlias()
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+ self.n_heads = self.model_args.n_heads
+ self.head_dim = self.model_args.v_head_dim
+
+ if self.n_heads % self.num_devices == 0:
+ self.num_local_heads = self.n_heads // self.num_devices
+ else:
+ n_local = math.ceil(self.n_heads / self.num_devices)
+ if n_local % 2 != 0:
+ n_local += 1
+ self.num_local_heads = n_local
+
+ self.block_size = self.model_args.block_size
+ self.algorithm: UnProjOAllReduceAlgorithm = algorithm
+
+ self.ref_unproj_o: torch.Tensor | None = None
+
+ self.tilert_weights: torch.Tensor | None = None
+ self.tilert_scales: torch.Tensor | None = None
+
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_var_init = False
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """
+ Get the weights list.
+
+ Returns:
+ List of weights.
+ """
+ return [self.tilert_weights, self.tilert_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor (full model).
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ unproj_o_weight = weights_map[self.ref_weights_alias.o_proj_weight]
+ unproj_o_scale = weights_map[self.ref_weights_alias.o_proj_scale_inv]
+
+ if self.n_heads % self.num_devices == 0:
+ unproj_o_weight = unproj_o_weight.reshape(self.dim, self.num_devices, -1)
+ unproj_o_weight = unproj_o_weight.transpose(0, 1)
+ unproj_o_scale = unproj_o_scale.reshape(
+ self.dim // self.block_size, self.num_devices, -1
+ )
+ unproj_o_scale = unproj_o_scale.transpose(0, 1)
+ else:
+ cols_per_head = self.head_dim
+ cols_per_dev = self.num_local_heads * cols_per_head
+ W = unproj_o_weight.view(self.dim, self.n_heads, cols_per_head)
+
+ scale_cols_per_head = cols_per_head // self.block_size
+ scale_cols_per_dev = self.num_local_heads * scale_cols_per_head
+ S = unproj_o_scale.view(self.dim // self.block_size, self.n_heads, scale_cols_per_head)
+
+ W_devs = []
+ S_devs = []
+ for dev in range(self.num_devices):
+ start = dev * self.num_local_heads
+ end = min(self.n_heads, start + self.num_local_heads)
+ real = max(0, end - start)
+
+ dev_W = torch.zeros(
+ self.dim,
+ self.num_local_heads,
+ cols_per_head,
+ dtype=W.dtype,
+ device=W.device,
+ )
+ if real > 0:
+ dev_W[:, :real] = W[:, start:end]
+ W_devs.append(dev_W.reshape(self.dim, cols_per_dev))
+
+ dev_S = torch.zeros(
+ self.dim // self.block_size,
+ self.num_local_heads,
+ scale_cols_per_head,
+ dtype=S.dtype,
+ device=S.device,
+ )
+ if real > 0:
+ dev_S[:, :real] = S[:, start:end]
+ S_devs.append(dev_S.reshape(self.dim // self.block_size, scale_cols_per_dev))
+
+ unproj_o_weight = torch.stack(W_devs, dim=0)
+ unproj_o_scale = torch.stack(S_devs, dim=0)
+
+ return {
+ self.tilert_weights_alias.unproj_weights: unproj_o_weight.contiguous(),
+ self.tilert_weights_alias.unproj_scales: unproj_o_scale.contiguous(),
+ }
+
+ def init_reference_weights(
+ self,
+ state_dict: dict[str, torch.Tensor],
+ device_id: int | None = None,
+ ) -> None:
+ """
+ Initialize the reference weights.
+
+ Args:
+ state_dict: State dictionary keyed by ref weight alias (full model).
+ device_id: Device ID for this shard; defaults to self.device_id.
+ """
+ did = self.device_id if device_id is None else device_id
+ sharded = self.device_sharding(state_dict)
+ weights = sharded[self.tilert_weights_alias.unproj_weights][did]
+ scales = sharded[self.tilert_weights_alias.unproj_scales][did]
+ self.ref_unproj_o = weight_dequant(weights, scales)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Initialize the tilert weights.
+
+ Args:
+ state_dict: State dictionary keyed by tilert weight alias (per-device).
+ """
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_weights, self.tilert_scales = UnProjOAllReduceWeightsConverter(
+ self.model_args, self.num_devices
+ ).dispatch(
+ self.algorithm,
+ [state_dict[alias] for alias in self.tilert_weights_alias()],
+ )
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """
+ Initialize the tilert variables.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{self.device_id}")
+ self.is_var_init = True
+
+ def init_random_weights(self) -> None:
+ """Initialize the random weights."""
+ unproj_o_weights = torch.randn(
+ self.dim,
+ self.n_heads * self.head_dim,
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ ).to(torch.float8_e4m3fn)
+
+ head_scale_dim = self.head_dim // self.block_size
+ dim_scale_dim = self.dim // self.block_size
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+ unproj_o_scales = torch.randn(
+ dim_scale_dim,
+ self.n_heads * head_scale_dim,
+ dtype=scale_dtype,
+ device=f"cuda:{self.device_id}",
+ )
+ ref_state_dict = {
+ self.ref_weights_alias.o_proj_weight: unproj_o_weights,
+ self.ref_weights_alias.o_proj_scale_inv: unproj_o_scales,
+ }
+
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ per_device_state = {k: v[self.device_id] for k, v in sharded.items()}
+ self.init_tilert_weights(per_device_state)
+
+ def golden_forward(
+ self,
+ vec_in: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Forward pass for the down-project module.
+
+ Args:
+ vec_in: Input vector.
+
+ Returns:
+ Output tensor.
+ """
+ assert self.ref_unproj_o is not None
+ bsz = vec_in.shape[0]
+ seq_len = vec_in.shape[1]
+ assert bsz == 1
+ res = vec_in.reshape(bsz, seq_len, -1).float() @ self.ref_unproj_o.T.float()
+ return res.to(torch.bfloat16)
+
+ def tilert_forward(
+ self,
+ vec_in: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ ) -> torch.Tensor:
+ assert self.hidden_out is not None
+ assert self.profile_logs is not None
+ assert self.algorithm is not None
+ unproj_o_allreduce(
+ vec_in,
+ self.tilert_weights,
+ self.tilert_scales,
+ x_in,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ compute_kernel_type=self.algorithm.value,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ vec_in: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(vec_in)
diff --git a/python/models/deepseek_v3_2/refs/__init__.py b/tilert/models/deepseek_v3_2/refs/__init__.py
similarity index 61%
rename from python/models/deepseek_v3_2/refs/__init__.py
rename to tilert/models/deepseek_v3_2/refs/__init__.py
index 25e6872..75aaf30 100644
--- a/python/models/deepseek_v3_2/refs/__init__.py
+++ b/tilert/models/deepseek_v3_2/refs/__init__.py
@@ -2,6 +2,10 @@
This package exposes helpers like `act_quant`, `fp8_gemm`, and `weight_dequant`
for tests and higher-level Python ops.
+
+Note: `act_quant` and `fp8_gemm` require tilelang at *call* time, and
+`weight_dequant` requires triton at *call* time, but importing this package
+does not require tilelang or triton to be installed.
"""
from .kernel import act_quant, fp8_gemm, weight_dequant
diff --git a/tilert/models/deepseek_v3_2/refs/kernel.py b/tilert/models/deepseek_v3_2/refs/kernel.py
new file mode 100644
index 0000000..cd68a7c
--- /dev/null
+++ b/tilert/models/deepseek_v3_2/refs/kernel.py
@@ -0,0 +1,306 @@
+import torch
+
+try:
+ import tilelang
+ import tilelang.language as T
+
+ _HAS_TILELANG = True
+except ImportError:
+ _HAS_TILELANG = False
+
+try:
+ import triton
+ import triton.language as tl
+
+ _HAS_TRITON = True
+except ImportError:
+ _HAS_TRITON = False
+
+__all__ = [
+ "weight_dequant",
+ "act_quant",
+ "fp8_gemm",
+]
+
+FP8 = "float8_e4m3"
+BF16 = "bfloat16"
+FP32 = "float32"
+
+
+def _require_tilelang(fn_name: str) -> None:
+ if not _HAS_TILELANG:
+ raise ImportError(f"{fn_name} requires tilelang. Install with: pip install tilelang")
+
+
+def _require_triton(fn_name: str) -> None:
+ if not _HAS_TRITON:
+ raise ImportError(f"{fn_name} requires triton. Install with: pip install triton")
+
+
+if _HAS_TRITON:
+
+ @triton.jit
+ def weight_dequant_kernel( # type: ignore
+ x_ptr,
+ s_ptr,
+ y_ptr,
+ M_Size: tl.constexpr,
+ N_Size: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ ) -> None:
+ """
+ Weight dequantization kernel.
+
+ Dequantizes weights using the provided scaling factors and stores the
+ result.
+
+ Args:
+ x_ptr (tl.pointer): Pointer to the quantized weights.
+ s_ptr (tl.pointer): Pointer to the scaling factors.
+ y_ptr (tl.pointer): Pointer to the output buffer for dequantized
+ weights.
+ M (int): Number of rows in the weight matrix.
+ N (int): Number of columns in the weight matrix.
+ BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+ Returns:
+ None
+ """
+ pid_m = tl.program_id(axis=0)
+ pid_n = tl.program_id(axis=1)
+ n_size = tl.cdiv(N_Size, BLOCK_SIZE)
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offs = offs_m[:, None] * N_Size + offs_n[None, :]
+ mask = (offs_m[:, None] < M_Size) & (offs_n[None, :] < N_Size)
+ x_in = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+ s_in = tl.load(s_ptr + pid_m * n_size + pid_n)
+ y_out = x_in * s_in
+ tl.store(y_ptr + offs, y_out, mask=mask)
+
+
+def _weight_dequant_torch(
+ x_in: torch.Tensor, s_in: torch.Tensor, block_size: int = 128
+) -> torch.Tensor:
+ """Pure-PyTorch fallback for weight_dequant (multi-GPU safe).
+
+ Used when triton is unavailable, or when the triton kernel raises at
+ launch time (e.g. ``cuPointerGetAttribute`` failing on non-device-0
+ GPUs during multi-device ``init_random_weights``).
+ """
+ M, N = x_in.shape
+ y = x_in.float().reshape(M // block_size, block_size, N // block_size, block_size)
+ y = y * s_in[:, None, :, None]
+ return y.reshape(M, N).to(torch.get_default_dtype())
+
+
+def weight_dequant(x_in: torch.Tensor, s_in: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+ """
+ Dequantizes the given weight tensor using the provided scale tensor.
+
+ Args:
+ x_in (torch.Tensor): The quantized weight tensor of shape (M, N).
+ s_in (torch.Tensor): The scale tensor of shape (M//block_size,
+ N//block_size).
+ block_size (int, optional): The block size to use for dequantization.
+ Defaults to 128.
+
+ Returns:
+ torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+ Raises:
+ AssertionError: If `x` or `s` are not contiguous or if their dimensions
+ are not 2.
+ """
+ assert x_in.is_contiguous() and s_in.is_contiguous(), "Input tensors must be contiguous"
+ assert x_in.dim() == 2 and s_in.dim() == 2, "Input tensors must have 2 dimensions"
+ if not _HAS_TRITON:
+ return _weight_dequant_torch(x_in, s_in, block_size)
+ M_Size, N_Size = x_in.size()
+ grid = lambda meta: ( # noqa: E731
+ triton.cdiv(M_Size, meta["BLOCK_SIZE"]),
+ triton.cdiv(N_Size, meta["BLOCK_SIZE"]),
+ )
+ try:
+ y_out = torch.empty_like(x_in, dtype=torch.get_default_dtype())
+ weight_dequant_kernel[grid](x_in, s_in, y_out, M_Size, N_Size, BLOCK_SIZE=block_size)
+ except (ValueError, RuntimeError):
+ return _weight_dequant_torch(x_in, s_in, block_size)
+ return y_out
+
+
+if _HAS_TILELANG:
+ tilelang.set_log_level("WARNING")
+
+ _pass_configs = {
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ }
+
+ def _fast_log2_ceil(x): # type: ignore
+ bits_x = T.reinterpret("uint32", x)
+ exp_x = (bits_x >> 23) & 0xFF
+ man_bits = bits_x & ((1 << 23) - 1)
+ return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
+
+ def _fast_pow2(x): # type: ignore
+ bits_x = (x + 127) << 23
+ return T.reinterpret("float32", bits_x)
+
+ def _fast_round_scale(amax, fp8_max_inv): # type: ignore
+ return _fast_pow2(_fast_log2_ceil(amax * fp8_max_inv))
+
+ @tilelang.jit(pass_configs=_pass_configs)
+ def act_quant_kernel( # type: ignore
+ N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False # type: ignore
+ ): # type: ignore
+ M = T.symbolic("M")
+ fp8_min = -448.0
+ fp8_max = 448.0
+ fp8_max_inv = 1 / fp8_max
+ num_stages = 0 if round_scale else 2
+ blk_m = 32
+ group_size = 128
+
+ @T.prim_func
+ def act_quant_kernel_( # type: ignore
+ X: T.Tensor[(M, N), in_dtype],
+ Y: T.Tensor[(M, N), out_dtype],
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
+ ): # type: ignore
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
+ pid_m,
+ pid_n,
+ ):
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
+ amax_local = T.alloc_fragment((blk_m,), scale_dtype)
+ s_local = T.alloc_fragment((blk_m,), scale_dtype)
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
+
+ for _ in T.Pipelined(1, num_stages=num_stages):
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
+ T.copy(x_shared, x_local)
+ T.reduce_absmax(x_local, amax_local, dim=1)
+ for i in T.Parallel(blk_m):
+ amax_local[i] = T.max(amax_local[i], 1e-4)
+ if round_scale:
+ s_local[i] = _fast_round_scale(amax_local[i], fp8_max_inv)
+ else:
+ s_local[i] = amax_local[i] * fp8_max_inv
+ for i, j in T.Parallel(blk_m, group_size):
+ y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
+ for i in T.Parallel(blk_m):
+ S[pid_m * blk_m + i, pid_n] = s_local[i]
+ T.copy(y_local, y_shared)
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
+
+ return act_quant_kernel_
+
+ @tilelang.jit(pass_configs=_pass_configs)
+ def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): # type: ignore
+ assert out_dtype in [BF16, "float32"]
+
+ M = T.symbolic("M")
+ group_size = 128
+ block_M = 32
+ block_N = 128
+ block_K = 128
+
+ @T.prim_func
+ def fp8_gemm_kernel_( # type: ignore
+ A: T.Tensor[(M, K), FP8],
+ B: T.Tensor[(N, K), FP8],
+ C: T.Tensor[(M, N), out_dtype],
+ scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
+ scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
+ ): # type: ignore
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
+ bx,
+ by,
+ ):
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
+ Scale_C_shared = T.alloc_shared((block_M), FP32)
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+ T.use_swizzle(panel_size=10)
+
+ T.clear(C_local)
+ T.clear(C_local_accum)
+ K_iters = T.ceildiv(K, block_K)
+ for k in T.Pipelined(K_iters, num_stages=4):
+ T.copy(A[by * block_M, k * block_K], A_shared)
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ Scale_B = scales_b[bx * block_N // group_size, k]
+ for i in T.Parallel(block_M):
+ Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
+
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+ for i, j in T.Parallel(block_M, block_N):
+ C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
+ T.clear(C_local)
+ T.copy(C_local_accum, C_shared)
+ T.copy(C_shared, C[by * block_M, bx * block_N])
+
+ return fp8_gemm_kernel_
+
+
+def act_quant(
+ x: torch.Tensor, block_size: int = 128, scale_fmt: str | None = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantizes the input tensor `x` using block-wise quantization.
+
+ Args:
+ x (torch.Tensor): The input tensor to be quantized.
+ Must be contiguous and its last dimension size must be divisible by `block_size`.
+ block_size (int, optional): The size of the blocks to be used for quantization.
+ Default is 128.
+ scale_fmt (Optional[str], optional): The format of the scale. Default is None.
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
+ - A tensor of scaling factors with dtype `torch.float32`.
+ """
+ _require_tilelang("act_quant")
+ assert x.is_contiguous(), "Input tensor must be contiguous"
+ assert (
+ x.size(-1) % block_size == 0
+ ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
+ N = x.size(-1)
+ y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+ s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
+ kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
+ kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
+ return y, s
+
+
+def fp8_gemm(
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
+) -> torch.Tensor:
+ """
+ Perform a matrix multiplication using FP8 precision.
+
+ Args:
+ a (torch.Tensor): The first input matrix, must be contiguous.
+ a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
+ b (torch.Tensor): The second input matrix, must be contiguous.
+ b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+
+ Returns:
+ torch.Tensor: The result of the matrix multiplication.
+ """
+ _require_tilelang("fp8_gemm")
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
+ assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling factor tensors must be contiguous"
+ K = a.size(-1)
+ M = a.numel() // K
+ N = b.size(0)
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+ kernel = fp8_gemm_kernel(N, K)
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
+ return c
diff --git a/python/models/deepseek_v3_2/temp_var_indices.py b/tilert/models/deepseek_v3_2/temp_var_indices.py
similarity index 72%
rename from python/models/deepseek_v3_2/temp_var_indices.py
rename to tilert/models/deepseek_v3_2/temp_var_indices.py
index 552fa3f..a4eca34 100644
--- a/python/models/deepseek_v3_2/temp_var_indices.py
+++ b/tilert/models/deepseek_v3_2/temp_var_indices.py
@@ -1,8 +1,6 @@
"""Named indices for DSA temporary variables.
-Mirrors the C++ ``DsaTempVars`` constants defined in
-``include/lib/models/deepseek_v3_2/helper.hpp`` so that Python code can
-reference temp_vars by name instead of magic numbers.
+Lets Python code reference temp_vars by name instead of magic numbers.
Usage::
@@ -15,7 +13,7 @@
class DsaTempVarIdx(IntEnum):
- """Index constants for DSA temp_vars, mirroring C++ DsaTempVars."""
+ """Index constants for DSA temp_vars."""
Q = 0
KV = 1
@@ -28,7 +26,7 @@ class DsaTempVarIdx(IntEnum):
IDX_LOGITS = 8
IDX_SELECTS = 9
Q_NOPE = 10
- O = 11 # noqa: E741 — mirrors C++ DsaTempVars::O
+ O = 11 # noqa: E741
O_ACC = 12
O_LSE = 13
O_LSE_ACC = 14
@@ -68,35 +66,36 @@ class DsaTempVarIdx(IntEnum):
SAMPLING_CONFIG = 48
TOP_P_SCORES = 49
TOP_P_DEBUG = 50
+ LORA_SLOT_ID = 51
+ LORA_RANK = 52
+ TOP_N_LOG_PROBS = 53
+ TOP_N_INDICES = 54
+ LOGPROBS_FLAG = 55
-# Sentinel: total number of temp vars. Must equal C++ DsaTempVars::temp_vars_size.
-TEMP_VARS_SIZE = 51
+TEMP_VARS_SIZE = 56
-# Short alias for convenient access
Idx = DsaTempVarIdx
def validate_temp_vars_layout() -> None:
- """Validate that the Python enum matches the C++ DsaTempVars layout.
+ """Validate the temporary-variable index enum.
Checks:
1. Enum member count equals TEMP_VARS_SIZE.
2. Indices are contiguous 0..TEMP_VARS_SIZE-1 with no gaps or duplicates.
- 3. (If libtilert.so is loaded) C++ temp_vars_size matches Python TEMP_VARS_SIZE.
+ 3. (If the backend is loaded) the backend temp_vars_size matches TEMP_VARS_SIZE.
Raises:
RuntimeError: If any validation check fails.
"""
members = list(DsaTempVarIdx)
- # Check member count
if len(members) != TEMP_VARS_SIZE:
raise RuntimeError(
f"DsaTempVarIdx has {len(members)} members but TEMP_VARS_SIZE={TEMP_VARS_SIZE}"
)
- # Check contiguous indices
indices = sorted(m.value for m in members)
expected = list(range(TEMP_VARS_SIZE))
if indices != expected:
@@ -107,16 +106,13 @@ def validate_temp_vars_layout() -> None:
f"Missing: {missing}, Duplicates: {set(dupes)}"
)
- # Check against C++ if the library is loaded
try:
import torch
cpp_size = torch.ops.tilert.dsa_temp_vars_size()
if cpp_size != TEMP_VARS_SIZE:
raise RuntimeError(
- f"Python TEMP_VARS_SIZE={TEMP_VARS_SIZE} != "
- f"C++ DsaTempVars::temp_vars_size={cpp_size}"
+ f"TEMP_VARS_SIZE={TEMP_VARS_SIZE} != " f"backend temp_vars_size={cpp_size}"
)
except (AttributeError, RuntimeError):
- # Library not loaded or op not available — skip C++ check
pass
diff --git a/python/models/glm_5/__init__.py b/tilert/models/glm_5/__init__.py
similarity index 100%
rename from python/models/glm_5/__init__.py
rename to tilert/models/glm_5/__init__.py
diff --git a/tilert/models/glm_5/_dsa_v32/__init__.py b/tilert/models/glm_5/_dsa_v32/__init__.py
new file mode 100644
index 0000000..4b8633b
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/__init__.py
@@ -0,0 +1 @@
+"""DeepSeek v3.2 model package."""
diff --git a/tilert/models/glm_5/_dsa_v32/generator.py b/tilert/models/glm_5/_dsa_v32/generator.py
new file mode 100644
index 0000000..26ee685
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/generator.py
@@ -0,0 +1,531 @@
+"""DSA show hands for deepseek v3.2."""
+
+import math
+import time
+
+import torch
+from transformers import AutoTokenizer
+
+from tilert import logger
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.end2end import ShowHandsDSALayer
+from tilert.models.glm_5._dsa_v32.temp_var_indices import Idx
+from tilert.tilert_init import tilert_init
+
+__all__ = [
+ "DSAv32Generator",
+ "stats_time",
+]
+
+
+def stats_time(time_list: list[float], title: str) -> None:
+ if len(time_list) > 0:
+ avg_time = sum(time_list) / len(time_list)
+ std_dev = math.sqrt(sum((x - avg_time) ** 2 for x in time_list) / len(time_list))
+ logger.info(title)
+ logger.info(f"--Average time taken to generate token: {avg_time * 1000:.4f} ms")
+ logger.info(f"--Standard deviation of time: {std_dev * 1000:.4f} ms")
+ logger.info(f"--Effective tokens per second: {1 / avg_time:.4f}")
+
+
+class DSAv32Generator:
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ max_new_tokens: int = 100,
+ temperature: float = 1.0,
+ model_weights_dir: str = "",
+ with_mtp: bool = False,
+ use_topp: bool = False,
+ top_p: float = 0.9,
+ top_k: int = 256,
+ sampling_seed: int = 42,
+ ):
+ """Initialize the DSAv32Generator.
+
+ Args:
+ max_new_tokens: Maximum number of new tokens to generate. Defaults to 100.
+ temperature: Temperature for sampling. Defaults to 1.0.
+ model_weights_dir: Path of the model weights directory.
+ with_mtp: Whether to use MTP (Multi-Token Prediction) for speculative decoding.
+ use_topp: Whether to use top-p (nucleus) sampling instead of top-1 (argmax).
+ top_p: Top-p threshold for nucleus sampling. Defaults to 0.9.
+ top_k: Number of top-k candidates for top-p sampling. Defaults to 256.
+ sampling_seed: Sampling seed for top-p (fixed per request). Defaults to 42.
+ """
+ torch.set_num_threads(64)
+ self.model_weights_dir = model_weights_dir
+
+ self.max_new_tokens = max_new_tokens
+ self.temperature = temperature
+ self.with_mtp = with_mtp
+ self.use_topp = use_topp
+ self.top_p = top_p
+ self.top_k = top_k
+ self.sampling_seed = sampling_seed
+
+ self.config = model_args
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_weights_dir, trust_remote_code=True
+ ) # nosec B615
+ self.eos_id = self.tokenizer.eos_token_id
+ self.batch_size = 1
+
+ self.default_device = torch.device("cuda:0")
+
+ self.decode_layer = ShowHandsDSALayer(
+ model_args=self.config,
+ model_path=self.model_weights_dir,
+ with_mtp=with_mtp,
+ use_topp=use_topp,
+ top_p=top_p,
+ top_k=top_k,
+ )
+
+ self.mtp_seq_len = 4 if with_mtp else 1
+
+ def init(self) -> None:
+ """Initialize the ShowHandsGenerator."""
+ tilert_init()
+
+ def cleanup(self) -> None:
+ """Cleanup the ShowHandsGenerator."""
+ self.decode_layer.cleanup()
+
+ def init_random_weights(self) -> None:
+ """Random initialize the weights."""
+ self.decode_layer.init_random_weights()
+
+ def from_pretrained(self) -> None:
+ """Load the model weights from the given path."""
+ self.decode_layer.from_pretrained(self.model_weights_dir)
+
+ def extract_ffn_cache(self) -> tuple[dict[int, list], dict[int, set[str]]]:
+ """Extract MOE/MLP op objects and skip keys from current loaded weights.
+
+ Returns:
+ Tuple of (cached_ffn_ops_per_device, skip_keys_per_device).
+ """
+ from tilert.models.glm_5._dsa_v32.modules.end2end import (
+ _extract_ffn_ops,
+ _get_moe_weight_keys,
+ )
+
+ cached_ffn_ops: dict[int, list] = {}
+ skip_keys: dict[int, set[str]] = {}
+ for device_id in range(self.decode_layer.num_devices):
+ dsa = self.decode_layer._dsa_objects[device_id]
+ if dsa is None:
+ raise RuntimeError(f"Device {device_id} Dsa not available for cache extraction")
+ cached_ffn_ops[device_id] = _extract_ffn_ops(dsa)
+ skip_keys[device_id] = _get_moe_weight_keys(dsa)
+ return cached_ffn_ops, skip_keys
+
+ def from_pretrained_with_cache(
+ self,
+ cached_ffn_ops_per_device: dict[int, list],
+ skip_keys_per_device: dict[int, set[str]],
+ ) -> None:
+ """Load weights reusing cached MOE/MLP ops."""
+ self.decode_layer.from_pretrained_with_cache(
+ self.model_weights_dir, cached_ffn_ops_per_device, skip_keys_per_device
+ )
+
+ def update_sampling_params(
+ self,
+ temperature: float = 1.0,
+ top_p: float = 0.95,
+ top_k: int = 256,
+ use_topp: bool = True,
+ ) -> None:
+ """Update sampling parameters for the next generation."""
+ self.temperature = temperature
+ self.use_topp = use_topp
+ self.top_p = top_p
+ self.top_k = top_k
+ self.decode_layer.update_sampling_config(
+ temperature=temperature, top_p=top_p, top_k=top_k, use_topp=use_topp
+ )
+
+ @torch.inference_mode()
+ def generate(
+ self,
+ prompt: str,
+ print_log: bool = True,
+ with_mtp: bool | None = None,
+ prompt_tokens: list[int] | None = None,
+ ) -> tuple[str, list[float], list[int], int]:
+ """Main function to load the model and perform single sequence generation.
+
+ Args:
+ prompt: The input prompt string.
+ print_log: Whether to print generation logs.
+ with_mtp: Override MTP mode for this call. None uses self.with_mtp.
+ Requires MTP weights to have been loaded (self.with_mtp=True).
+ prompt_tokens: Pre-tokenized prompt tokens. If provided, skip tokenization
+ and use these tokens directly (useful for exact-length benchmarking).
+
+ Returns:
+ Tuple of (result_text, time_list, accepted_counts, prompt_len).
+ accepted_counts is empty for non-MTP mode.
+ """
+ active_mtp = with_mtp if with_mtp is not None else self.with_mtp
+ if active_mtp and not self.with_mtp:
+ raise ValueError("Cannot use MTP mode: MTP weights were not loaded")
+ self.decode_layer.set_sampling_seed(self.sampling_seed, with_mtp=active_mtp)
+ if active_mtp:
+ return self._generate_with_mtp(prompt, print_log, prompt_tokens=prompt_tokens)
+ result, time_list, prompt_len = self._generate_without_mtp(
+ prompt, print_log, with_mtp=active_mtp, prompt_tokens=prompt_tokens
+ )
+ return result, time_list, [], prompt_len
+
+ def _generate_without_mtp(
+ self,
+ prompt: str,
+ print_log: bool = True,
+ with_mtp: bool = False,
+ prompt_tokens: list[int] | None = None,
+ ) -> tuple[str, list[float], int]:
+ """Standard generation without MTP."""
+ if prompt_tokens is None:
+ prompt_tokens = self.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}], add_generation_prompt=True
+ )
+
+ max_seq_len = self.config.max_seq_len
+ prompt_len = len(prompt_tokens)
+ total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
+
+ tokens = torch.full(
+ (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device
+ )
+ tokens[0, :prompt_len] = torch.tensor(
+ prompt_tokens, dtype=torch.long, device=self.default_device
+ )
+ prompt_mask = tokens != -1
+
+ prev_pos = 0
+ finished = torch.tensor(
+ [False] * self.batch_size, dtype=torch.bool, device=self.default_device
+ )
+
+ time_list = []
+ for cur_pos_val in range(1, total_len):
+ start_time = time.time()
+ multi_devices_results = self.decode_layer.forward(
+ tokens[0, prev_pos], with_mtp=with_mtp
+ )
+ end_time = time.time()
+ time_list.append(end_time - start_time)
+
+ intermediates, *_ = multi_devices_results[0]
+ next_token = intermediates[Idx.TOKEN_OUT][0][0]
+
+ next_token = torch.where(
+ prompt_mask[0, cur_pos_val], tokens[0, cur_pos_val], next_token
+ )
+ tokens[0, cur_pos_val] = next_token
+ finished |= torch.logical_and(~prompt_mask[0, cur_pos_val], next_token == self.eos_id)
+ prev_pos = cur_pos_val
+ if cur_pos_val >= prompt_len:
+ decoded_tokens = self.tokenizer.decode(
+ [next_token.item()], skip_special_tokens=True
+ )
+ if print_log:
+ print(decoded_tokens, end="", flush=True)
+
+ if finished.all():
+ break
+
+ if print_log:
+ print("\n")
+ logger.info(f"--Number of tokens generated: {len(time_list)}")
+
+ stats_time(time_list, "==== Performance ====")
+ print("\n")
+
+ self.decode_layer.reset_sequence()
+
+ completion_tokens = []
+ for _, toks in enumerate(tokens.tolist()):
+ toks = toks[prompt_len : prompt_len + self.max_new_tokens]
+ if self.eos_id in toks:
+ toks = toks[: toks.index(self.eos_id)]
+ completion_tokens.append(toks)
+
+ decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+
+ return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list, prompt_len
+
+ def _generate_with_mtp(
+ self,
+ prompt: str,
+ print_log: bool = True,
+ prompt_tokens: list[int] | None = None,
+ ) -> tuple[str, list[float], list[int], int]:
+ """Generation with MTP (Multi-Token Prediction) speculative decoding."""
+ if prompt_tokens is None:
+ prompt_tokens = self.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}], add_generation_prompt=True
+ )
+
+ max_seq_len = self.config.max_seq_len
+ prompt_len = len(prompt_tokens)
+ total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
+
+ tokens = torch.full(
+ (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device
+ )
+ tokens[0, :prompt_len] = torch.tensor(
+ prompt_tokens, dtype=torch.long, device=self.default_device
+ )
+
+ prefill_time_list = []
+ decode_time_list = []
+ decode_accepted_counts = []
+ cur_pos = 0
+
+ while cur_pos < prompt_len - 1:
+ draft_end = min(cur_pos + self.mtp_seq_len, prompt_len)
+ draft_tokens = tokens[0, cur_pos:draft_end].clone()
+ actual_token_count = draft_tokens.shape[0]
+
+ if actual_token_count < self.mtp_seq_len:
+ pad_token = draft_tokens[-1].item()
+ padding = torch.full(
+ (self.mtp_seq_len - actual_token_count,),
+ pad_token,
+ dtype=torch.long,
+ device=self.default_device,
+ )
+ draft_tokens = torch.cat([draft_tokens, padding])
+
+ draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
+
+ mtp_extra_pos = cur_pos + self.mtp_seq_len
+ if mtp_extra_pos < prompt_len:
+ mtp_extra_token = int(tokens[0, mtp_extra_pos].item())
+ else:
+ mtp_extra_token = int(tokens[0, draft_end - 1].item())
+ self.decode_layer.set_prefill_mtp_extra_token(mtp_extra_token)
+
+ self.decode_layer.set_prefill_valid_tokens(actual_token_count)
+
+ start_time = time.time()
+ self.decode_layer.forward(draft_tokens, with_mtp=True)
+ end_time = time.time()
+ prefill_time_list.append(end_time - start_time)
+
+ cur_pos += actual_token_count
+
+ cur_pos = prompt_len - 1
+ self.set_cur_pos(prompt_len - 1)
+
+ self.decode_layer.set_prefill_valid_tokens(0)
+
+ finished = False
+ while cur_pos < total_len - 1 and not finished:
+ if cur_pos == prompt_len - 1:
+ last_token = tokens[0, prompt_len - 1].item()
+ draft_tokens = torch.full(
+ (self.mtp_seq_len,),
+ last_token,
+ dtype=torch.long,
+ device=self.default_device,
+ )
+ draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
+ else:
+ draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape(
+ 1, self.mtp_seq_len
+ )
+
+ start_time = time.time()
+ self.decode_layer.forward(draft_tokens, with_mtp=True)
+ end_time = time.time()
+ decode_time_list.append(end_time - start_time)
+
+ num_accepted = self.decode_layer.get_num_accepted(0)
+ predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten()
+ decode_accepted_counts.append(num_accepted)
+
+ num_output_tokens = num_accepted
+ for i in range(num_output_tokens):
+ if cur_pos + 1 + i >= total_len:
+ break
+ new_token = int(predicted_tokens[i].item())
+ tokens[0, cur_pos + 1 + i] = new_token
+
+ if cur_pos + 1 + i >= prompt_len and print_log:
+ decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True)
+ print(decoded_text, end="", flush=True)
+
+ if new_token == self.eos_id:
+ finished = True
+ break
+
+ cur_pos += num_accepted
+
+ if print_log:
+ print("\n")
+ total_tokens = sum(decode_accepted_counts)
+ logger.info(f"--Number of forward calls (decode): {len(decode_accepted_counts)}")
+ logger.info(f"--Total tokens generated: {total_tokens}")
+ if len(decode_accepted_counts) > 0:
+ avg_accepted = sum(decode_accepted_counts) / len(decode_accepted_counts)
+ min_accepted = min(decode_accepted_counts)
+ max_accepted = max(decode_accepted_counts)
+ logger.info(
+ f"--Accepted tokens per call: mean={avg_accepted:.2f}, "
+ f"min={min_accepted}, max={max_accepted}"
+ )
+
+ if decode_time_list:
+ total_decode_time = sum(decode_time_list)
+ effective_tps = total_tokens / total_decode_time if total_decode_time > 0 else 0
+ avg_time_ms = total_decode_time / len(decode_time_list) * 1000
+ logger.info(f"--Avg forward time: {avg_time_ms:.2f}ms")
+ logger.info(f"--Effective TPS (with MTP): {effective_tps:.2f} tokens/s")
+
+ print("\n")
+
+ self.decode_layer.reset_sequence()
+
+ completion_tokens = []
+ for _, toks in enumerate(tokens.tolist()):
+ toks = toks[prompt_len : prompt_len + self.max_new_tokens]
+ toks = [t for t in toks if t != -1]
+ if self.eos_id in toks:
+ toks = toks[: toks.index(self.eos_id)]
+ completion_tokens.append(toks)
+
+ decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+
+ return (
+ f"{decoded_tokens[0]}\n" if decoded_tokens else "",
+ decode_time_list,
+ decode_accepted_counts,
+ prompt_len,
+ )
+
+ def inject_cache(
+ self,
+ layer_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
+ start_pos: int = 0,
+ end_pos: int | None = None,
+ ) -> None:
+ """Inject external cache data into TileRT.
+
+ This API allows injecting pre-computed KI/KV/PE cache data from an external
+ prefill system, enabling prefill-decode disaggregation.
+
+ Args:
+ layer_caches: List of (ki, kv, pe) tuples for each layer (0 to NUM_LAYERS-1).
+ Each tensor should be BF16 with shape [seqlen, dim] where:
+ - ki: [seqlen, 128] - compressed key
+ - kv: [seqlen, 512] - compressed key-value
+ - pe: [seqlen, 64] - position encoding cache
+ start_pos: Start position in cache to write (0-indexed). Defaults to 0.
+ end_pos: End position in cache (exclusive). If None, uses seqlen from tensors.
+
+ Example:
+ >>> # Load cache from external prefill system
+ >>> layer_caches = [] # List of 61 (ki, kv, pe) tuples
+ >>> for layer_id in range(61):
+ ... ki = load_ki_for_layer(layer_id) # [seqlen, 128] bf16
+ ... kv = load_kv_for_layer(layer_id) # [seqlen, 512] bf16
+ ... pe = load_pe_for_layer(layer_id) # [seqlen, 64] bf16
+ ... layer_caches.append((ki, kv, pe))
+ >>> generator.inject_cache(layer_caches, start_pos=0)
+ >>> generator.set_cur_pos(seqlen) # Set RoPE position
+ >>> # Continue generation from cache
+ """
+ num_layers = len(layer_caches)
+ if num_layers == 0:
+ logger.warning("inject_cache called with empty layer_caches")
+ return
+
+ first_ki, _, _ = layer_caches[0]
+ seqlen = first_ki.size(0)
+ if end_pos is None:
+ end_pos = start_pos + seqlen
+
+ cache_len = end_pos - start_pos
+ logger.info(f"Injecting cache: {num_layers} layers, positions [{start_pos}, {end_pos})")
+
+ num_devices = self.decode_layer.num_devices
+
+ for device_id in range(num_devices):
+ _, caches, _, _ = self.decode_layer._get_device_result(device_id)
+
+ for layer_id, (ki, kv, pe) in enumerate(layer_caches):
+ if layer_id >= num_layers:
+ logger.warning(f"Layer index {layer_id} is out of bounds, skipping.")
+ break
+
+ base_idx = layer_id * 3
+
+ ki_src = ki[:cache_len].to(f"cuda:{device_id}")
+ kv_src = kv[:cache_len].to(f"cuda:{device_id}")
+ pe_src = pe[:cache_len].to(f"cuda:{device_id}")
+
+ caches[base_idx + 0][0, start_pos:end_pos, :].copy_(ki_src)
+ caches[base_idx + 1][0, start_pos:end_pos, :].copy_(kv_src)
+ caches[base_idx + 2][0, start_pos:end_pos, :].copy_(pe_src)
+
+ logger.info(f"Cache injection completed for {num_devices} devices")
+
+ def set_cur_pos(self, cur_pos: int) -> None:
+ """Set the current position for RoPE.
+
+ This should be called after inject_cache() to ensure the runtime position
+ matches the injected cache length, for correct RoPE position encoding
+ during continued generation.
+
+ Args:
+ cur_pos: The current sequence position (typically the length of prefilled tokens).
+
+ Example:
+ >>> generator.inject_cache(layer_caches, start_pos=0)
+ >>> generator.set_cur_pos(prefill_len) # Set position to prefill length
+ >>> # Now generate continues from the correct position
+ """
+ if self.with_mtp:
+ num_devices = self.decode_layer.num_devices
+ for device_id in range(num_devices):
+ intermediates, _, _, _ = self.decode_layer._get_device_result(device_id)
+ cur_pos_tensor = intermediates[Idx.CUR_POS]
+ cur_pos_tensor.fill_(cur_pos)
+ else:
+ torch.ops.tilert.dsa_show_hands_set_cur_pos(cur_pos)
+
+ def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None:
+ """Inject the last hidden state for MTP mode.
+
+ For MTP (Multi-Token Prediction), the MTP preprocess layer needs the
+ last hidden state from the main model's last token.
+
+ Args:
+ last_hidden_state: [hidden_size] or [1, hidden_size] BF16 tensor.
+ The hidden state of the last token from prefill.
+
+ Example:
+ >>> # After inject_cache, inject the last hidden state for MTP
+ >>> generator.inject_last_hidden_state(last_hidden_state)
+ >>> # Then set cur_pos and start generation
+ """
+ if not self.with_mtp:
+ logger.warning("inject_last_hidden_state called but with_mtp is False, skipping")
+ return
+
+ if last_hidden_state.dim() == 1:
+ last_hidden_state = last_hidden_state.unsqueeze(0)
+
+ num_devices = self.decode_layer.num_devices
+ for device_id in range(num_devices):
+ intermediates, _, _, _ = self.decode_layer._get_device_result(device_id)
+ lhs_tensor = intermediates[Idx.LAST_HIDDEN_STATES]
+ lhs_src = last_hidden_state.to(f"cuda:{device_id}")
+ lhs_tensor[0, 0, :].copy_(lhs_src.squeeze(0))
+
+ logger.info(f"Injected last_hidden_state to {num_devices} devices")
diff --git a/tilert/models/glm_5/_dsa_v32/model_args.py b/tilert/models/glm_5/_dsa_v32/model_args.py
new file mode 100644
index 0000000..441b684
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/model_args.py
@@ -0,0 +1,95 @@
+"""Model arguments and hyperparameters."""
+
+from dataclasses import dataclass
+from typing import Literal
+
+__all__ = [
+ "ModelArgs",
+]
+
+
+@dataclass
+class ModelArgs:
+ """
+ Data class for defining model arguments and hyperparameters.
+
+ Attributes:
+ arch_name (str): Architecture name.
+ max_batch_size (int): Maximum batch size.
+ max_seq_len (int): Maximum sequence length.
+ dtype (Literal["bf16", "fp8"]): Data type for computations.
+ scale_fmt (Optional[str]): Format for quantization scale.
+ vocab_size (int): Vocabulary size.
+ dim (int): Model dimension.
+ inter_dim (int): Intermediate dimension for MLP layers.
+ moe_inter_dim (int): Intermediate dimension for MoE layers.
+ n_layers (int): Number of transformer layers.
+ n_dense_layers (int): Number of dense layers in the model.
+ n_heads (int): Number of attention heads.
+ n_routed_experts (int): Number of routed experts for MoE layers.
+ n_shared_experts (int): Number of shared experts for MoE layers.
+ n_activated_experts (int): Number of activated experts in MoE layers.
+ n_expert_groups (int): Number of expert groups.
+ n_limited_groups (int): Number of limited groups for MoE routing.
+ score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
+ route_scale (float): Scaling factor for routing scores.
+ q_lora_rank (int): LoRA rank for query projections.
+ kv_lora_rank (int): LoRA rank for key-value projections.
+ qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
+ qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
+ v_head_dim (int): Dimension for value projections.
+ original_seq_len (Optional[int]): Original sequence length.
+ rope_theta (float): Base for rotary positional encoding.
+ rope_factor (Optional[float]): Scaling factor for extended sequence lengths.
+ beta_fast (Optional[int]): Fast beta correction factor.
+ beta_slow (Optional[int]): Slow beta correction factor.
+ mscale (float): Scaling factor for extended attention.
+ index_head_dim (int): Dimension for index head.
+ index_topk (int): Top-k for index head.
+ """
+
+ arch_name = "deepseek_v3_2"
+
+ max_batch_size: int = 1
+ max_seq_len: int = 160 * 1024
+ dtype: Literal["bf16", "fp8"] = "fp8"
+ scale_fmt: str | None = None
+
+ vocab_size: int = 129280
+ dim: int = 7168
+ inter_dim: int = 18432
+ moe_inter_dim: int = 2048
+ n_layers: int = 61
+ n_dense_layers: int = 3
+ n_heads: int = 128
+
+ n_routed_experts: int = 256
+ n_shared_experts: int = 1
+ n_activated_experts: int = 8
+ n_expert_groups: int = 8
+ n_limited_groups: int = 4
+ score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "softmax"
+ route_scale: float = 2.5
+
+ q_lora_rank: int = 1536
+ kv_lora_rank: int = 512
+ qk_nope_head_dim: int = 128
+ qk_rope_head_dim: int = 64
+ v_head_dim: int = 128
+
+ original_seq_len: int | None = 4096
+ rope_theta: float = 10000.0
+ rope_factor: float | None = 40
+ beta_fast: int | None = 32
+ beta_slow: int | None = 1
+ mscale: float = 1.0
+
+ index_n_heads: int = 64
+ index_head_dim: int = 128
+ index_topk: int = 2048
+
+ kv_cache_pad: int = 8
+
+ block_size: int = 128
+
+ eps: float = 1e-6
diff --git a/tilert/models/glm_5/_dsa_v32/modules/__init__.py b/tilert/models/glm_5/_dsa_v32/modules/__init__.py
new file mode 100644
index 0000000..937085b
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/__init__.py
@@ -0,0 +1,11 @@
+"""DeepSeek v3.2 high-level Python modules (MLA, MLP, MTP, etc.)."""
+
+__all__ = [
+ "dsa",
+ "end2end",
+ "mla",
+ "mlp",
+ "moe",
+ "mtp",
+ "mtp_preprocess",
+]
diff --git a/tilert/models/glm_5/_dsa_v32/modules/dsa.py b/tilert/models/glm_5/_dsa_v32/modules/dsa.py
new file mode 100644
index 0000000..38a01c1
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/dsa.py
@@ -0,0 +1,229 @@
+from typing import Any
+
+import torch
+
+from tilert.models.base import SerializableTileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.mlp import MlpBlock
+from tilert.models.glm_5._dsa_v32.modules.moe import MoeBlock
+from tilert.models.glm_5._dsa_v32.ops import RMSNormHeadProj
+from tilert.models.glm_5._dsa_v32.temp_var_indices import TEMP_VARS_SIZE, Idx
+
+
+class Dsa(SerializableTileRTModule):
+ """DSA module."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ cached_ffn_ops: list | None = None,
+ ):
+ super().__init__(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ remove_selected=True,
+ )
+ from tilert.models.glm_5._dsa_v32.modules.mla_v2 import (
+ PureMlaV2,
+ SparseSelectMlaV2,
+ )
+
+ mla_cls = SparseSelectMlaV2 if device_id == 0 else PureMlaV2
+ mla_kwargs: dict = {}
+
+ dev = f"cuda:{device_id}"
+ n_peers = num_devices - 1
+ if device_id == 0:
+ self.v2_peer_bufs = torch.zeros(n_peers, dtype=torch.int64, device=dev)
+ self.v2_partial_buf = torch.zeros(
+ model_args.max_batch_size, 4, model_args.dim, dtype=torch.bfloat16, device=dev
+ )
+ mla_kwargs = {
+ "peer_bufs": self.v2_peer_bufs,
+ "partial_buf": self.v2_partial_buf,
+ }
+ else:
+ max_seq_len = getattr(model_args, "num_mtp", 3) + 1
+ topk = model_args.index_topk
+ self.v2_ll_buf = torch.zeros(max_seq_len * topk * 2, dtype=torch.int32, device=dev)
+ mla_kwargs = {"ll_buf": self.v2_ll_buf}
+
+ mla_num_devices: int | None = None
+ if device_id != 0:
+ mla_num_devices = num_devices - 1
+
+ if cached_ffn_ops is not None:
+ assert (
+ len(cached_ffn_ops) == model_args.n_layers
+ ), f"Expected {model_args.n_layers} cached FFN ops, got {len(cached_ffn_ops)}"
+
+ for layer_idx in range(model_args.n_layers):
+ ffn_op = cached_ffn_ops[layer_idx] if cached_ffn_ops else None
+ if layer_idx < model_args.n_dense_layers:
+ block = MlpBlock(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ mla_cls=mla_cls,
+ mla_num_devices=mla_num_devices,
+ mla_kwargs=mla_kwargs,
+ mlp=ffn_op,
+ )
+ else:
+ block = MoeBlock(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ mla_cls=mla_cls,
+ mla_num_devices=mla_num_devices,
+ mla_kwargs=mla_kwargs,
+ moe=ffn_op,
+ )
+ self.register_op(block, prefix=f"layer_{layer_idx}_", suffix=f"_dev_{device_id}")
+
+ self.register_op(
+ RMSNormHeadProj(model_args=model_args, device_id=device_id, num_devices=num_devices),
+ prefix=f"layer_{model_args.n_layers}_",
+ suffix=f"_dev_{device_id}",
+ retain_weights=True,
+ )
+
+ self.embed_tokens_weight = None
+ self.freqs_cis = None
+
+ def init_tilert_weights(self, state_dicts: dict[str, torch.Tensor]) -> None:
+ super().init_tilert_weights(state_dicts)
+ self.embed_tokens_weight = state_dicts["model.embed_tokens.weight"]
+ self.freqs_cis = state_dicts["freqs_cis"]
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [*super().get_weights_list(), self.embed_tokens_weight, self.freqs_cis]
+
+ def get_temp_vars(
+ self, batch_size: int, seq_len: int, extra_args: dict[str, Any] | None = None
+ ) -> list[torch.Tensor]:
+ bf16_desc = {"dtype": torch.bfloat16, "device": f"cuda:{self.device_id}"}
+ fp32_desc = {"dtype": torch.float32, "device": f"cuda:{self.device_id}"}
+ int32_desc = {"dtype": torch.int32, "device": f"cuda:{self.device_id}"}
+ int64_desc = {"dtype": torch.int64, "device": f"cuda:{self.device_id}"}
+ fp8_desc = {"dtype": torch.float8_e4m3fn, "device": f"cuda:{self.device_id}"}
+
+ assert extra_args is not None
+ temperature = extra_args["temperature"]
+ top_p = extra_args["top_p"]
+ top_k = extra_args["top_k"]
+ use_topp = extra_args["use_topp"]
+
+ dim = self.model_args.dim
+ batch_seq = (batch_size, seq_len)
+ q_lora_rank = self.model_args.q_lora_rank
+ kv_lora_rank = self.model_args.kv_lora_rank
+ qk_nope_head_dim = self.model_args.qk_nope_head_dim
+ if self.device_id != 0:
+ from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqbWeightsConverter,
+ )
+
+ qk_head_dim = self.model_args.qk_nope_head_dim + self.model_args.qk_rope_head_dim
+ n_local_heads = RmsnormProjqWqbWeightsConverter._compute_n_local_heads(
+ self.model_args.n_heads, self.num_devices - 1, qk_head_dim
+ )
+ else:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ index_head_dim = self.model_args.index_head_dim
+ v_head_dim = self.model_args.v_head_dim
+ n_index_heads = self.model_args.index_n_heads
+ max_seq_len = self.model_args.max_seq_len
+ index_topk = self.model_args.index_topk
+ n_routed_experts = self.model_args.n_routed_experts
+ n_activated_experts = self.model_args.n_activated_experts
+ n_total_experts = self.model_args.n_activated_experts + self.model_args.n_shared_experts
+ moe_inter_dim = self.model_args.moe_inter_dim // self.num_devices
+ vocab_size = self.model_args.vocab_size // self.num_devices
+
+ temp_vars: list[torch.Tensor | None] = [None] * TEMP_VARS_SIZE
+
+ temp_vars[Idx.Q] = torch.zeros(*batch_seq, q_lora_rank, **bf16_desc)
+ temp_vars[Idx.KV] = torch.zeros(*batch_seq, kv_lora_rank, **bf16_desc)
+ temp_vars[Idx.KI] = torch.zeros(*batch_seq, index_head_dim, **bf16_desc)
+ temp_vars[Idx.Q_NOPE_DOWN] = torch.zeros(
+ *batch_seq, n_local_heads, qk_nope_head_dim, **bf16_desc
+ )
+ temp_vars[Idx.Q_PE] = torch.zeros(*batch_seq, n_local_heads, qk_rope_head_dim, **bf16_desc)
+ temp_vars[Idx.IQ] = torch.zeros(*batch_seq, n_index_heads, index_head_dim, **bf16_desc)
+ temp_vars[Idx.IQ_RT] = torch.zeros(*batch_seq, n_index_heads, index_head_dim, **bf16_desc)
+ temp_vars[Idx.IDX_SCORES] = torch.zeros(*batch_seq, n_index_heads, **bf16_desc)
+ temp_vars[Idx.IDX_LOGITS] = torch.zeros(
+ *batch_seq, max_seq_len + self.model_args.kv_cache_pad, **fp32_desc
+ )
+ temp_vars[Idx.IDX_SELECTS] = torch.zeros(*batch_seq, index_topk, **int32_desc)
+ temp_vars[Idx.Q_NOPE] = torch.zeros(*batch_seq, n_local_heads, kv_lora_rank, **bf16_desc)
+ temp_vars[Idx.O] = torch.zeros(*batch_seq, n_local_heads, kv_lora_rank, **bf16_desc)
+ temp_vars[Idx.O_ACC] = torch.zeros(*batch_seq, n_local_heads, 32, kv_lora_rank, **fp32_desc)
+ temp_vars[Idx.O_LSE] = torch.empty(*batch_seq, n_local_heads, **fp32_desc)
+ temp_vars[Idx.O_LSE_ACC] = torch.empty(*batch_seq, n_local_heads, 32, **fp32_desc)
+ temp_vars[Idx.PROJ_O] = torch.zeros(*batch_seq, n_local_heads, v_head_dim, **bf16_desc)
+ temp_vars[Idx.UNPROJ_O] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.SCORES] = torch.zeros(*batch_seq, n_routed_experts, **fp32_desc)
+ temp_vars[Idx.X_MLP_IN] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ exp_up_gate = torch.zeros(*batch_seq, n_total_experts, moe_inter_dim, **bf16_desc)
+ temp_vars[Idx.UP_GATE] = exp_up_gate
+ temp_vars[Idx.SEL_PROBS] = torch.zeros(*batch_seq, n_activated_experts, **fp32_desc)
+ temp_vars[Idx.SEL_INDICES] = torch.zeros(*batch_seq, n_activated_experts, **int32_desc)
+ temp_vars[Idx.EXP_OUT] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.X_RMSNORM] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.LOGITS_OUT] = torch.zeros(*batch_seq, vocab_size, **fp32_desc)
+ temp_vars[Idx.TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc)
+
+ temp_vars[Idx.EMBEDDING_RMSNORM] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.HIDDEN_RMSNORM] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.EH_PROJ] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.X_TENSOR] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.ROPE_FREQS] = torch.zeros(*batch_seq, qk_rope_head_dim, **fp32_desc)
+ temp_vars[Idx.CUR_POS] = torch.zeros(batch_size, **int32_desc)
+ temp_vars[Idx.TOKEN_ID] = torch.zeros(*batch_seq, 1, **int32_desc)
+ temp_vars[Idx.LAST_HIDDEN_STATES] = torch.zeros(*batch_seq, dim, **bf16_desc)
+
+ temp_vars[Idx.DRAFT_TOKENS] = torch.zeros(*batch_seq, **int32_desc)
+ temp_vars[Idx.PREDICTED_TOKENS] = torch.zeros(*batch_seq, 1, **int32_desc)
+ temp_vars[Idx.PREDICTED_HIDDEN] = torch.zeros(*batch_seq, dim, **bf16_desc)
+ temp_vars[Idx.ACCEPTED_TOKENS] = torch.zeros(batch_size, **int32_desc)
+ temp_vars[Idx.NEXT_DRAFT_TOKENS] = torch.zeros(*batch_seq, **int32_desc)
+
+ temp_vars[Idx.X_QUANT] = torch.zeros(*batch_seq, dim, **fp8_desc)
+ temp_vars[Idx.X_SCALE] = torch.zeros(
+ *batch_seq, dim // self.model_args.block_size, **fp32_desc
+ )
+ temp_vars[Idx.MOE_UP_GATE] = torch.zeros_like(exp_up_gate)
+
+ temp_vars[Idx.IDX_SEL_WS] = torch.zeros(*batch_seq, (200 * 1024 + 260), **int32_desc)
+
+ temp_vars[Idx.MTP0_TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc)
+ temp_vars[Idx.MTP1_TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc)
+ temp_vars[Idx.MTP0_EXP_OUT] = torch.zeros(*batch_seq, dim, **bf16_desc)
+
+ temp_vars[Idx.SAMPLING_SEED] = torch.zeros(*batch_seq, **int64_desc)
+ temp_vars[Idx.SAMPLING_POSITIONS] = torch.zeros(*batch_seq, **int64_desc)
+ temp_vars[Idx.SAMPLING_CONFIG] = torch.tensor(
+ [temperature, top_p, top_k, use_topp], **fp32_desc
+ )
+ temp_vars[Idx.TOP_P_SCORES] = torch.zeros(*batch_seq, **fp32_desc)
+ temp_vars[Idx.TOP_P_DEBUG] = torch.zeros(*batch_seq, vocab_size, **fp32_desc)
+
+ temp_vars[Idx.LORA_SLOT_ID] = torch.zeros(1, **int32_desc)
+ temp_vars[Idx.LORA_RANK] = torch.zeros(1, **int32_desc)
+
+ max_top_n = 256
+ temp_vars[Idx.TOP_N_LOG_PROBS] = torch.zeros(*batch_seq, max_top_n, **fp32_desc)
+ temp_vars[Idx.TOP_N_INDICES] = torch.zeros(*batch_seq, max_top_n, **int32_desc)
+ temp_vars[Idx.LOGPROBS_FLAG] = torch.zeros(1, **int32_desc)
+
+ for i, t in enumerate(temp_vars):
+ if t is None:
+ raise RuntimeError(f"temp_vars[{i}] ({Idx(i).name}) was not initialized")
+
+ return temp_vars # type: ignore[return-value]
diff --git a/tilert/models/glm_5/_dsa_v32/modules/end2end.py b/tilert/models/glm_5/_dsa_v32/modules/end2end.py
new file mode 100644
index 0000000..6b4e69c
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/end2end.py
@@ -0,0 +1,703 @@
+"""DSA show hands for deepseek v3.2."""
+
+import json
+import os
+import sys
+import threading
+import time
+from typing import Any
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import load_file
+
+from tilert import logger
+from tilert.models.base import TileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.dsa import Dsa
+from tilert.models.glm_5._dsa_v32.modules.mtp import MTP
+from tilert.models.glm_5._dsa_v32.temp_var_indices import Idx, validate_temp_vars_layout
+from tilert.models.utils import precompute_freqs_cis
+from tilert.utils import get_profile_log_tensor
+
+__all__ = ["ShowHandsDSALayer", "_extract_ffn_ops", "_get_moe_weight_keys"]
+
+
+DeviceResult = tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], torch.Tensor]
+
+
+def _mark_weights_initialized(module: TileRTModule) -> None:
+ """Recursively mark a module and all sub-ops as having initialized tilert weights."""
+ module.is_tilert_weights_init = True
+ if hasattr(module, "exec_seq"):
+ for op in module.exec_seq:
+ _mark_weights_initialized(op)
+
+
+def _extract_ffn_ops(dsa: "Dsa") -> list:
+ """Extract Moe/Mlp op objects from a Dsa's layer blocks.
+
+ Returns a list of length n_layers where each element is a Moe or Mlp instance.
+ """
+ from tilert.models.glm_5._dsa_v32.modules.mlp import MlpBlock
+ from tilert.models.glm_5._dsa_v32.modules.moe import MoeBlock
+
+ ffn_ops = []
+ for block in dsa.exec_seq:
+ if isinstance(block, MoeBlock):
+ op = block.moe
+ _mark_weights_initialized(op)
+ ffn_ops.append(op)
+ elif isinstance(block, MlpBlock):
+ op = block.mlp
+ _mark_weights_initialized(op)
+ ffn_ops.append(op)
+
+ assert (
+ len(ffn_ops) == dsa.model_args.n_layers
+ ), f"Expected {dsa.model_args.n_layers} FFN ops, got {len(ffn_ops)}"
+ return ffn_ops
+
+
+def _get_moe_weight_keys(dsa: "Dsa") -> set[str]:
+ """Get state_dict keys that belong exclusively to MOE/MLP ops in this Dsa."""
+ from tilert.models.glm_5._dsa_v32.modules.mlp import MlpBlock
+ from tilert.models.glm_5._dsa_v32.modules.moe import MoeBlock
+
+ moe_keys: set[str] = set()
+ mla_keys: set[str] = set()
+ for block, prefix, suffix in zip(dsa.exec_seq, dsa.prefix_seq, dsa.suffix_seq):
+ if isinstance(block, (MoeBlock, MlpBlock)):
+ ffn = block.moe if isinstance(block, MoeBlock) else block.mlp
+ for alias in ffn.get_tilert_weights_alias():
+ moe_keys.add(f"{prefix}{alias}{suffix}")
+ for alias in block.mla.get_tilert_weights_alias():
+ mla_keys.add(f"{prefix}{alias}{suffix}")
+ return moe_keys - mla_keys
+
+
+def dsa_show_hands_prepare_money(
+ params: list[torch.Tensor],
+ temp_vars: list[torch.Tensor],
+ cache_vars: list[torch.Tensor],
+ profile_logs: torch.Tensor,
+ forward_max_seq_len: int,
+ with_mtp: bool = False,
+ is_glm5: bool = False,
+) -> Any:
+ """Prepare money for show hands"""
+ mtp_flag = "_mtp_e2e" if with_mtp else ""
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands_prepare_money{glm5_flag}"
+ if mtp_flag:
+ return getattr(torch.ops.tilert, func_name)(params, temp_vars, cache_vars, profile_logs)
+ return getattr(torch.ops.tilert, func_name)(
+ params, temp_vars, cache_vars, profile_logs, forward_max_seq_len
+ )
+
+
+def dsa_show_hands(token_id: torch.Tensor, with_mtp: bool = False, is_glm5: bool = False) -> Any:
+ """Show hands with native MT"""
+ mtp_flag = "_mtp_e2e" if with_mtp else ""
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands{glm5_flag}"
+ return getattr(torch.ops.tilert, func_name)(token_id)
+
+
+def dsa_show_hands_reset(with_mtp: bool = False, is_glm5: bool = False) -> Any:
+ """Reset show one hand"""
+ mtp_flag = "_mtp_e2e" if with_mtp else ""
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands_reset{glm5_flag}"
+ return getattr(torch.ops.tilert, func_name)()
+
+
+def dsa_show_hands_go_home(with_mtp: bool = False, is_glm5: bool = False) -> Any:
+ """Go home"""
+ mtp_flag = "_mtp_e2e" if with_mtp else ""
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands_go_home{glm5_flag}"
+ return getattr(torch.ops.tilert, func_name)()
+
+
+def dsa_show_hands_set_sampling_seed(
+ seed: int, with_mtp: bool = False, is_glm5: bool = False
+) -> Any:
+ """Set the sampling seed (request-level, fixed for the entire request).
+
+ Args:
+ seed: The sampling seed value.
+ """
+ mtp_flag = "_mtp_e2e" if with_mtp else ""
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands_set_sampling_seed{glm5_flag}"
+ return getattr(torch.ops.tilert, func_name)(seed)
+
+
+def dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(
+ num_valid_tokens: int, is_glm5: bool = False
+) -> Any:
+ """Set the number of valid (non-padding) tokens for prefill mode.
+
+ This controls how many tokens are copied from draft_tokens to predicted_tokens
+ during prefill. Should be called before forward() when the chunk has padding.
+
+ Args:
+ num_valid_tokens: Number of valid tokens in the chunk (1-4).
+ """
+ mtp_flag = "_mtp_e2e"
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands_set_prefill_valid_tokens{glm5_flag}"
+ return getattr(torch.ops.tilert, func_name)(num_valid_tokens)
+
+
+def dsa_mtp_e2e_show_hands_set_prefill_mtp_extra_token(token: int, is_glm5: bool = False) -> Any:
+ """Set the extra token for MTP[0] shifted input during prefill.
+
+ Args:
+ token: The extra prompt token id (int32).
+ """
+ mtp_flag = "_mtp_e2e"
+ glm5_flag = "_glm5" if is_glm5 else ""
+ func_name = f"dsa{mtp_flag}_show_hands_set_prefill_mtp_extra_token{glm5_flag}"
+ return getattr(torch.ops.tilert, func_name)(token)
+
+
+class ShowHandsDSALayer:
+ """Show hands DSA for deepseek v3.2."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ model_path: str = "",
+ with_weight_conversion: bool = True,
+ with_mtp: bool = False,
+ temperature: float = 1.0,
+ top_p: float = 0.9,
+ top_k: int = 256,
+ use_topp: bool = False,
+ ) -> None:
+ validate_temp_vars_layout()
+ print(f"Model args: {model_args.arch_name}")
+ for k_arg, v_arg in model_args.__dict__.items():
+ print(f" - {k_arg}: {v_arg}")
+ self.model_args = model_args
+ self.is_glm5 = self.model_args.arch_name == "glm_5"
+ assert self.model_args.arch_name in ["deepseek_v3_2", "glm_5"]
+
+ self.num_devices = 8
+ self.forward_max_seq_len = 4
+
+ self.model_path = model_path
+ self.with_weight_conversion = with_weight_conversion
+ self.with_mtp = with_mtp
+
+ self.multi_devices_results: list[DeviceResult | None] = [None] * torch.cuda.device_count()
+ self._dsa_objects: list[Dsa | None] = [None] * torch.cuda.device_count()
+
+ self.temperature = temperature
+ self.top_p = top_p
+ self.top_k = top_k
+ self.use_topp = use_topp
+
+ def _gen_freqs_cis(self) -> torch.Tensor:
+ freqs_cis = precompute_freqs_cis(self.model_args)
+ return torch.view_as_real(freqs_cis).reshape(freqs_cis.shape[0], -1)
+
+ def load_device_weights(
+ self,
+ model_path: str,
+ device_id: int,
+ extra_keys: list,
+ skip_keys: set[str] | None = None,
+ ) -> dict[str, torch.Tensor]:
+ index_file = "model.safetensors.index.json"
+ with open(os.path.join(model_path, index_file), encoding="utf-8") as f:
+ weights_index = json.load(f)
+ weight_file_map = weights_index["weight_map"]
+
+ weights_list = [_k for _k in weight_file_map.keys() if _k.endswith(f"dev_{device_id}")]
+ weights_list = [*weights_list, *extra_keys]
+
+ if skip_keys:
+ weights_list = [k for k in weights_list if k not in skip_keys]
+
+ target_files = set()
+ for weight_key in weights_list:
+ weight_file = weight_file_map[weight_key]
+ target_files.add(weight_file)
+
+ state_dicts = {}
+ weights_set = set(weights_list)
+ for weight_file in target_files:
+ filepath = os.path.join(model_path, weight_file)
+ if skip_keys:
+ logger.info(
+ f"Selectively loading weights from {weight_file} for device {device_id}"
+ )
+ with safe_open(filepath, framework="pt", device=f"cuda:{device_id}") as f:
+ for key in f.keys():
+ if key in weights_set:
+ state_dicts[key] = f.get_tensor(key)
+ torch.cuda.empty_cache()
+ else:
+ logger.info(f"Loading weights from {weight_file} for device {device_id}")
+ state_dict = load_file(filepath, device=f"cuda:{device_id}")
+ state_dicts.update(state_dict)
+ del state_dict
+ torch.cuda.empty_cache()
+
+ state_dicts["freqs_cis"] = self._gen_freqs_cis().to(device_id)
+ return state_dicts
+
+ def update_sampling_config(
+ self, temperature: float, top_p: float, top_k: int, use_topp: bool = True
+ ) -> None:
+ """Update sampling config, re-capturing CUDA graphs if parameters changed."""
+ new_config = (temperature, top_p, top_k, use_topp)
+ current_config = (self.temperature, self.top_p, self.top_k, self.use_topp)
+ if new_config == current_config:
+ return
+
+ print(
+ f"Recapturing CUDA graphs: "
+ f"temperature={temperature}, top_p={top_p}, top_k={top_k}, use_topp={use_topp}"
+ )
+
+ if self.with_mtp:
+ dsa_show_hands_go_home(True, self.is_glm5)
+ dsa_show_hands_go_home(False, self.is_glm5)
+ else:
+ dsa_show_hands_go_home(False, self.is_glm5)
+
+ self.temperature = temperature
+ self.top_p = top_p
+ self.top_k = top_k
+ self.use_topp = use_topp
+
+ for device_id in range(self.num_devices):
+ result = self.multi_devices_results[device_id]
+ if result is not None:
+ intermediates = result[0]
+ intermediates[Idx.SAMPLING_CONFIG].copy_(
+ torch.tensor(
+ [temperature, top_p, float(top_k), 1.0 if use_topp else 0.0],
+ dtype=torch.float32,
+ device=f"cuda:{device_id}",
+ )
+ )
+
+ for device_id in range(self.num_devices):
+ with torch.cuda.device(device_id):
+ intermediates, caches, params, profile_logs = self._get_device_result(device_id)
+ dsa_show_hands_prepare_money(
+ params,
+ intermediates,
+ caches,
+ profile_logs,
+ self.forward_max_seq_len,
+ self.with_mtp,
+ self.is_glm5,
+ )
+ if self.with_mtp:
+ dsa_show_hands_prepare_money(
+ params[: self._base_params_count],
+ intermediates,
+ caches[: self._base_caches_count],
+ profile_logs,
+ self.forward_max_seq_len,
+ False,
+ self.is_glm5,
+ )
+
+ @staticmethod
+ def tot_size_in_bytes_aligned(temp_vars: list[torch.Tensor], aligned_size: int) -> int:
+ tot_size: int = 0
+ for param in temp_vars:
+ aligned_param_size = (param.nbytes + aligned_size - 1) // aligned_size * aligned_size
+ tot_size += aligned_param_size
+ return tot_size
+
+ def generate_params_with_continuous_storage(
+ self, temp_vars: list[torch.Tensor], device: torch.device, aligned_size: int = 1024
+ ) -> list[torch.Tensor]:
+ tot_size = self.tot_size_in_bytes_aligned(temp_vars, aligned_size)
+ cloned_params = []
+ large_tensor = torch.zeros(tot_size, device=device, dtype=torch.uint8)
+ offset = 0
+ for param in temp_vars:
+ aligned_param_size = (param.nbytes + aligned_size - 1) // aligned_size * aligned_size
+ cloned_params.append(
+ large_tensor[offset : offset + param.nbytes].view(param.dtype).view(param.shape)
+ )
+ offset += aligned_param_size
+ return cloned_params
+
+ def _init_weights(
+ self,
+ model_path: str | None,
+ cached_ffn_ops_per_device: dict[int, list] | None = None,
+ skip_keys_per_device: dict[int, set[str]] | None = None,
+ ) -> None:
+ """Load the model weights from the given path or generate random weights.
+
+ Args:
+ model_path: Path to the model weights directory.
+ cached_ffn_ops_per_device: Optional dict mapping device_id to cached FFN ops.
+ When provided, these ops are injected into the Dsa and their weights
+ are not re-loaded from disk.
+ skip_keys_per_device: Optional dict mapping device_id to safetensors keys
+ to skip during loading. Used together with cached_ffn_ops_per_device.
+ """
+ self._v2_p2p: dict = {}
+
+ def __load_weights(device_id: int, model_path: str | None) -> None:
+ intermediates: list[torch.Tensor] = []
+ caches: list[torch.Tensor] = []
+ params: list[torch.Tensor] = []
+ state_dicts = {}
+ start_time = time.time()
+ with torch.cuda.device(device_id):
+ assert model_path is not None
+ skip_keys = (
+ skip_keys_per_device.get(device_id)
+ if skip_keys_per_device is not None
+ else None
+ )
+ state_dicts = self.load_device_weights(
+ model_path,
+ device_id,
+ [
+ "model.embed_tokens.weight",
+ f"layer_{self.model_args.n_layers}_lm_head.weight_dev_{device_id}",
+ f"layer_{self.model_args.n_layers}_model.norm.weight_dev_{device_id}",
+ ],
+ skip_keys=skip_keys,
+ )
+
+ cached_ffn_ops = (
+ cached_ffn_ops_per_device.get(device_id)
+ if cached_ffn_ops_per_device is not None
+ else None
+ )
+ dsa = Dsa(
+ self.model_args,
+ device_id,
+ self.num_devices,
+ cached_ffn_ops=cached_ffn_ops,
+ )
+ dsa.init_tilert_weights(state_dicts)
+ self._dsa_objects[device_id] = dsa
+ params.extend(dsa.get_weights_list())
+ caches.extend(dsa.get_cache_vars())
+
+ if device_id == 0:
+ self._v2_p2p[device_id] = {
+ "peer_bufs": dsa.v2_peer_bufs,
+ }
+ else:
+ self._v2_p2p[device_id] = {
+ "ll_buf": dsa.v2_ll_buf,
+ }
+ intermediates.extend(
+ self.generate_params_with_continuous_storage(
+ dsa.get_temp_vars(
+ 1,
+ self.forward_max_seq_len,
+ {
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "use_topp": self.use_topp,
+ },
+ ),
+ device_id,
+ )
+ )
+
+ sampling_config = intermediates[Idx.SAMPLING_CONFIG]
+ sampling_config.copy_(
+ torch.tensor(
+ [
+ self.temperature,
+ self.top_p,
+ float(self.top_k),
+ 1.0 if self.use_topp else 0.0,
+ ],
+ dtype=torch.float32,
+ device=device_id,
+ )
+ )
+
+ base_params_count = len(params)
+ base_caches_count = len(caches)
+
+ if self.with_mtp:
+ from tilert.models.glm_5._dsa_v32.modules.mla_v2 import (
+ PureMlaV2,
+ SparseSelectMlaV2,
+ )
+
+ mtp_kwargs: dict = {}
+ mtp_kwargs["mla_cls"] = SparseSelectMlaV2 if device_id == 0 else PureMlaV2
+ mtp_kwargs["mla_num_devices"] = 1 if device_id == 0 else self.num_devices - 1
+ if device_id == 0:
+ mtp_kwargs["mla_kwargs"] = {
+ "peer_bufs": dsa.v2_peer_bufs,
+ }
+ else:
+ mtp_kwargs["mla_kwargs"] = {"ll_buf": dsa.v2_ll_buf}
+ mtp = MTP(self.model_args, device_id, self.num_devices, **mtp_kwargs)
+ mtp.init_tilert_weights(state_dicts)
+ params.extend(mtp.get_weights_list())
+ caches.extend(mtp.get_cache_vars())
+ logger.info(f"Loaded real MTP weights for device {device_id}")
+
+ profile_logs = get_profile_log_tensor(device=device_id, num_max_insts=65536)
+ result = (intermediates, caches, params, profile_logs)
+ self.multi_devices_results[device_id] = result
+ self._base_params_count = base_params_count
+ self._base_caches_count = base_caches_count
+
+ del state_dicts
+ torch.cuda.empty_cache()
+ elapsed_time = time.time() - start_time
+ minutes = int(elapsed_time // 60)
+ seconds = int(elapsed_time % 60)
+ time_str = (
+ f"{minutes} minutes {seconds} seconds" if minutes > 0 else f"{seconds} seconds"
+ )
+ logger.info(f"Completed loading weights for device {device_id} in {time_str}")
+
+ threads = []
+ exceptions: list[Exception | None] = [None] * self.num_devices
+ for device_id in range(self.num_devices):
+
+ def _runner(dev_id: int) -> None:
+ try:
+ __load_weights(dev_id, model_path)
+ except Exception as exc: # pragma: no cover - surfaced after join
+ exceptions[dev_id] = exc
+
+ thread = threading.Thread(target=_runner, args=(device_id,))
+ threads.append(thread)
+ thread.start()
+ for thread in threads:
+ thread.join()
+ for device_id, exc in enumerate(exceptions):
+ if exc is not None:
+ raise RuntimeError(f"Failed to initialize device {device_id}: {exc}") from exc
+
+ if self._v2_p2p:
+ gpu0 = self._v2_p2p[0]
+ peer_bufs_cpu = torch.zeros(self.num_devices - 1, dtype=torch.int64)
+ for i in range(self.num_devices - 1):
+ dev_id = i + 1
+ peer_bufs_cpu[i] = self._v2_p2p[dev_id]["ll_buf"].data_ptr()
+ gpu0["peer_bufs"].copy_(peer_bufs_cpu)
+ logger.info(
+ "V2 P2P exchange complete: peer_bufs (ll_buf)=%s",
+ [hex(int(x)) for x in peer_bufs_cpu],
+ )
+
+ for device_id in range(self.num_devices):
+ with torch.cuda.device(device_id):
+ intermediates, caches, params, profile_logs = self._get_device_result(device_id)
+ dsa_show_hands_prepare_money(
+ params,
+ intermediates,
+ caches,
+ profile_logs,
+ self.forward_max_seq_len,
+ self.with_mtp,
+ self.is_glm5,
+ )
+ if self.with_mtp:
+ dsa_show_hands_prepare_money(
+ params[: self._base_params_count],
+ intermediates,
+ caches[: self._base_caches_count],
+ profile_logs,
+ self.forward_max_seq_len,
+ False,
+ self.is_glm5,
+ )
+
+ def from_pretrained(self, model_path: str) -> None:
+ """Load the model weights from the given path."""
+ if not os.path.exists(model_path):
+ raise ValueError(f"Model weights directory {model_path} does not exist")
+ self._init_weights(model_path)
+
+ def from_pretrained_with_cache(
+ self,
+ model_path: str,
+ cached_ffn_ops_per_device: dict[int, list],
+ skip_keys_per_device: dict[int, set[str]],
+ ) -> None:
+ """Load weights with cached MOE/MLP ops."""
+ if not os.path.exists(model_path):
+ raise ValueError(f"Model weights directory {model_path} does not exist")
+ self._init_weights(
+ model_path,
+ cached_ffn_ops_per_device=cached_ffn_ops_per_device,
+ skip_keys_per_device=skip_keys_per_device,
+ )
+
+ def init_random_weights(self) -> None:
+ """Generate random weights."""
+ self._init_weights(None)
+
+ def forward(
+ self,
+ token_id: torch.Tensor,
+ with_mtp: bool | None = None,
+ ) -> list[DeviceResult]:
+ active_mtp = with_mtp if with_mtp is not None else self.with_mtp
+ dsa_show_hands(token_id.cpu(), active_mtp, self.is_glm5)
+ return [self._get_device_result(device_id) for device_id in range(self.num_devices)]
+
+ def set_sampling_seed(self, seed: int, with_mtp: bool | None = None) -> None:
+ """Set the sampling seed for top-p sampling.
+
+ The seed is fixed for the entire request. Position provides per-step variation.
+
+ Args:
+ seed: The sampling seed value.
+ with_mtp: Override MTP mode for this call. Defaults to self.with_mtp.
+ """
+ active_mtp = with_mtp if with_mtp is not None else self.with_mtp
+ dsa_show_hands_set_sampling_seed(seed, active_mtp, self.is_glm5)
+
+ def reset_sequence(self) -> None:
+ if self.with_mtp:
+ dsa_show_hands_reset(True, self.is_glm5)
+ dsa_show_hands_reset(False, self.is_glm5)
+ else:
+ dsa_show_hands_reset(False, self.is_glm5)
+
+ def cleanup(self) -> None:
+ if self.with_mtp:
+ dsa_show_hands_go_home(True, self.is_glm5)
+ dsa_show_hands_go_home(False, self.is_glm5)
+ else:
+ dsa_show_hands_go_home(False, self.is_glm5)
+
+ def __del__(self) -> None:
+ try:
+ self.cleanup()
+ except Exception as e:
+ print(f"Exception during cleanup: {e}", file=sys.stderr)
+
+ def _get_device_result(self, device_id: int) -> DeviceResult:
+ device_result = self.multi_devices_results[device_id]
+ if device_result is None:
+ raise RuntimeError(f"Device {device_id} is not initialized")
+ return device_result
+
+ def set_prefill_valid_tokens(self, num_valid_tokens: int) -> None:
+ """Set the number of valid tokens for prefill mode.
+
+ This controls how many tokens are copied from draft_tokens to predicted_tokens
+ during prefill. Should be called before forward() when the chunk has padding.
+
+ Args:
+ num_valid_tokens: Number of valid tokens in the chunk (1-4).
+ """
+ dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(num_valid_tokens, self.is_glm5)
+
+ def set_prefill_mtp_extra_token(self, token: int) -> None:
+ """Set the extra token for MTP[0] shifted input during prefill.
+
+ Args:
+ token: The prompt token at (cur_pos + mtp_seq_len).
+ """
+ dsa_mtp_e2e_show_hands_set_prefill_mtp_extra_token(token, self.is_glm5)
+
+ def get_next_draft_tokens(self, device_id: int = 0) -> torch.Tensor:
+ """Get next_draft_tokens from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ next_draft_tokens tensor of shape [1, MTP_SEQ_LEN].
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return intermediates[Idx.NEXT_DRAFT_TOKENS]
+
+ def get_num_accepted(self, device_id: int = 0) -> int:
+ """Get number of accepted tokens from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Number of accepted tokens.
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return int(intermediates[Idx.ACCEPTED_TOKENS][0].item())
+
+ def get_predicted_tokens(self, device_id: int = 0) -> torch.Tensor:
+ """Get predicted_tokens from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ predicted_tokens tensor containing main model predictions.
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return intermediates[Idx.PREDICTED_TOKENS]
+
+ def get_logits(self, device_id: int = 0) -> torch.Tensor:
+ """Get logits from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Logits tensor of shape [batch, seq_len, vocab_size] (FP32).
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return intermediates[Idx.LOGITS_OUT]
+
+ def get_top_n_logprobs(self, device_id: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
+ """Get top-N log-probabilities and token IDs from the top_p kernel.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Tuple of (log_probs, token_ids):
+ - log_probs: [batch, seq_len, 256] FP32
+ - token_ids: [batch, seq_len, 256] INT32
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return (
+ intermediates[Idx.TOP_N_LOG_PROBS],
+ intermediates[Idx.TOP_N_INDICES],
+ )
+
+ def get_token_logprob(self, device_id: int = 0) -> torch.Tensor:
+ """Get log-probability of the sampled token (from TOP_P_SCORES).
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Tensor of shape [batch, seq_len] (FP32).
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ return intermediates[Idx.TOP_P_SCORES]
+
+ def set_logprobs_enabled(self, enabled: bool) -> None:
+ """Enable or disable logprobs export in the top_p kernel.
+
+ Args:
+ enabled: True to enable logprobs export, False to disable.
+ """
+ flag_val = 1 if enabled else 0
+ for device_id in range(self.num_devices):
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ intermediates[Idx.LOGPROBS_FLAG].fill_(flag_val)
diff --git a/tilert/models/glm_5/_dsa_v32/modules/mla_v2.py b/tilert/models/glm_5/_dsa_v32/modules/mla_v2.py
new file mode 100644
index 0000000..d9a9dd1
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/mla_v2.py
@@ -0,0 +1,248 @@
+"""MLA weight generator classes for device-group-specific pipelines."""
+
+import torch
+
+from tilert.models.base import SerializableTileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.ops.layernorm_rope_rotate import LayerNormRoPERotate
+from tilert.models.glm_5._dsa_v32.ops.projo_wkvb import ProjoWKVb
+from tilert.models.glm_5._dsa_v32.ops.projq_wqb import ProjqWqb
+from tilert.models.glm_5._dsa_v32.ops.projx_wis import ProjxWis
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_kv import KVRMSNorm
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqb,
+ RmsnormProjqWqbAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqi import (
+ RmsnormProjqWqi,
+ RmsnormProjqWqiAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projx_wqakis import (
+ RMSNormProjxWqakis,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projx_wqkva import (
+ RMSNormProjxWqkva,
+ RMSNormProjxWqkvaAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.unproj_o_allreduce import (
+ UnProjOAllReduce,
+ UnProjOAllReduceAlgorithm,
+)
+
+
+class SparseSelectMlaV2(SerializableTileRTModule):
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ peer_bufs: torch.Tensor | None = None,
+ partial_buf: torch.Tensor | None = None,
+ ):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.rmsnorm_projx_wqakis = RMSNormProjxWqakis(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.rmsnorm_projx_wqakis)
+
+ self.rmsnorm_projq_wqi = RmsnormProjqWqi(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.rmsnorm_projq_wqi.algorithm = RmsnormProjqWqiAlgorithm.FP16MMA
+ self.register_op(self.rmsnorm_projq_wqi)
+
+ self.layernorm_rope_rotate = LayerNormRoPERotate(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.layernorm_rope_rotate)
+
+ self.projx_wis = ProjxWis(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.projx_wis)
+
+ self.peer_bufs = peer_bufs
+ self.partial_buf = partial_buf
+
+ self.ki_cache: torch.Tensor | None = None
+ self.kv_cache: torch.Tensor | None = None
+ self.pe_cache: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """Return weight tensors."""
+ weights = super().get_weights_list()
+
+ dev = f"cuda:{self.device_id}"
+ if self.peer_bufs is None:
+ self.peer_bufs = torch.zeros(self.num_devices - 1, dtype=torch.int64, device=dev)
+ if self.partial_buf is None:
+ self.partial_buf = torch.zeros(
+ self.model_args.max_batch_size,
+ 4,
+ self.model_args.dim,
+ dtype=torch.bfloat16,
+ device=dev,
+ )
+
+ weights.append(self.peer_bufs)
+ weights.append(self.partial_buf)
+
+ return weights
+
+ def get_cache_vars(self) -> list[torch.Tensor]:
+ """Return [ki_cache, kv_cache, pe_cache] matching DsaCacheVars layout."""
+ cache_seq_len = self.model_args.max_seq_len + self.model_args.kv_cache_pad
+ bs_args = (self.model_args.max_batch_size, cache_seq_len)
+
+ if self.ki_cache is None:
+ ki_dim = self.model_args.index_head_dim
+ self.ki_cache = torch.zeros(
+ *bs_args, ki_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.kv_cache is None:
+ kv_dim = self.model_args.kv_lora_rank
+ self.kv_cache = torch.zeros(
+ *bs_args, kv_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.pe_cache is None:
+ pe_dim = self.model_args.qk_rope_head_dim
+ self.pe_cache = torch.zeros(
+ *bs_args, pe_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ return [*super().get_cache_vars(), self.ki_cache, self.kv_cache, self.pe_cache]
+
+
+class PureMlaV2(SerializableTileRTModule):
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ ll_buf: torch.Tensor | None = None,
+ ):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.rmsnorm_projx_wqkva = RMSNormProjxWqkva(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.rmsnorm_projx_wqkva.algorithm = RMSNormProjxWqkvaAlgorithm.DECOUPLED
+ self.register_op(self.rmsnorm_projx_wqkva)
+
+ self.rmsnorm_projq_wqb = RmsnormProjqWqb(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.rmsnorm_projq_wqb.algorithm = RmsnormProjqWqbAlgorithm.FP16MMA
+ self.register_op(self.rmsnorm_projq_wqb)
+
+ self.rmsnorm_kv = KVRMSNorm(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.rmsnorm_kv)
+
+ self.projq_wqb = ProjqWqb(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.projq_wqb)
+
+ self.projo_wkvb = ProjoWKVb(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.projo_wkvb)
+
+ allreduce_algo = UnProjOAllReduceAlgorithm.FP16MMA
+ self.unproj_o_allreduce = UnProjOAllReduce(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ algorithm=allreduce_algo,
+ )
+ self.register_op(self.unproj_o_allreduce)
+
+ self.ll_buf = ll_buf
+
+ self.ki_cache: torch.Tensor | None = None
+ self.kv_cache: torch.Tensor | None = None
+ self.pe_cache: torch.Tensor | None = None
+
+ def init_random_weights(self) -> None:
+ """Initialize random weights for this module."""
+ super().init_random_weights()
+
+ from tilert.models.common import init_func
+
+ for op in [self.projq_wqb, self.projo_wkvb]:
+ padded_total = op.num_local_heads * op.num_devices
+ w = init_func(
+ torch.empty(
+ padded_total * op.wkvb_head_dim, op.wkvb_lora_rank, dtype=torch.float8_e4m3fn
+ )
+ )
+ s = init_func(
+ torch.empty(
+ padded_total * op.wkvb_head_dim // op.model_args.block_size,
+ op.wkvb_lora_rank_qsize,
+ dtype=torch.float32,
+ )
+ )
+ ref_dict = dict(zip(op.ref_weights_alias(), [w, s]))
+ op.init_reference_weights(ref_dict)
+ sharded = op.device_sharding(ref_dict)
+ per_dev = {k: v[op.device_id] for k, v in sharded.items()}
+ op.init_tilert_weights_hmma(per_dev)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Load TileRT weights for this module from state_dict."""
+ self.projq_wqb.is_tilert_weights_init = True
+ self.projo_wkvb.is_tilert_weights_init = True
+
+ super().init_tilert_weights(state_dict)
+
+ for op in [self.projq_wqb, self.projo_wkvb]:
+ op_state_dict = {}
+ for op_key in op.get_tilert_weights_alias():
+ for p, s in zip(self.prefix_seq, self.suffix_seq):
+ original_key = f"{p}{op_key}{s}"
+ if original_key in state_dict:
+ op_state_dict[op_key] = state_dict[original_key]
+ break
+ op.is_tilert_weights_init = False
+ op.init_tilert_weights_hmma(op_state_dict)
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """Return weight tensors."""
+ weights = super().get_weights_list()
+
+ if self.ll_buf is None:
+ max_seq_len = getattr(self.model_args, "num_mtp", 3) + 1
+ topk = self.model_args.index_topk
+ self.ll_buf = torch.zeros(
+ max_seq_len * topk * 2, dtype=torch.int32, device=f"cuda:{self.device_id}"
+ )
+
+ weights.append(self.ll_buf)
+
+ return weights
+
+ def get_cache_vars(self) -> list[torch.Tensor]:
+ """Return [ki_cache, kv_cache, pe_cache] matching DsaCacheVars layout."""
+ cache_seq_len = self.model_args.max_seq_len + self.model_args.kv_cache_pad
+ bs_args = (self.model_args.max_batch_size, cache_seq_len)
+
+ if self.ki_cache is None:
+ ki_dim = self.model_args.index_head_dim
+ self.ki_cache = torch.zeros(
+ *bs_args, ki_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.kv_cache is None:
+ kv_dim = self.model_args.kv_lora_rank
+ self.kv_cache = torch.zeros(
+ *bs_args, kv_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ if self.pe_cache is None:
+ pe_dim = self.model_args.qk_rope_head_dim
+ self.pe_cache = torch.zeros(
+ *bs_args, pe_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}"
+ )
+ return [*super().get_cache_vars(), self.ki_cache, self.kv_cache, self.pe_cache]
diff --git a/tilert/models/glm_5/_dsa_v32/modules/mlp.py b/tilert/models/glm_5/_dsa_v32/modules/mlp.py
new file mode 100644
index 0000000..85fec25
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/mlp.py
@@ -0,0 +1,74 @@
+from tilert.models.base import SerializableTileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.mla_v2 import PureMlaV2 as Mla
+from tilert.models.glm_5._dsa_v32.ops.down_allreduce import (
+ DownAllReduce,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_up_gate_silu import (
+ RMSNormUpGateSiLU,
+ RMSNormUpGateSiLUAlgorithm,
+)
+
+
+class Mlp(SerializableTileRTModule):
+ """Implement the MLP operations."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ ):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.rmsnorm_mlp_up_gate_silu = RMSNormUpGateSiLU(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+ self.rmsnorm_mlp_up_gate_silu.algorithm = RMSNormUpGateSiLUAlgorithm.FP16MMA
+ self.register_op(self.rmsnorm_mlp_up_gate_silu)
+
+ self.rmsnorm_mlp_down = DownAllReduce(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.rmsnorm_mlp_down)
+
+
+class MlpBlock(SerializableTileRTModule):
+ """Implement the MOE block operations."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ remove_selected: bool = False,
+ mla_cls: type | None = None,
+ mla_num_devices: int | None = None,
+ mla_kwargs: dict | None = None,
+ mlp: "Mlp | None" = None,
+ ):
+ super().__init__(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ remove_selected=remove_selected,
+ )
+
+ mla_class = mla_cls or Mla
+ mla_nd = mla_num_devices if mla_num_devices is not None else num_devices
+ self.mla = mla_class(
+ model_args=model_args, device_id=device_id, num_devices=mla_nd, **(mla_kwargs or {})
+ )
+ self.register_op(self.mla)
+ self.mlp = (
+ mlp
+ if mlp is not None
+ else Mlp(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+ )
+ self.register_op(self.mlp)
diff --git a/tilert/models/glm_5/_dsa_v32/modules/moe.py b/tilert/models/glm_5/_dsa_v32/modules/moe.py
new file mode 100644
index 0000000..5410284
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/moe.py
@@ -0,0 +1,80 @@
+import torch
+
+from tilert.models.base import SerializableTileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.mla_v2 import PureMlaV2 as Mla
+from tilert.models.glm_5._dsa_v32.ops.expert_down_allreduce import (
+ ExpertDownAllReduce,
+)
+from tilert.models.glm_5._dsa_v32.ops.expert_sel_up_gate_silu import (
+ ExpertSelectUpGateSiLU,
+ ExpertSelectUpGateSiLUAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_expert_proj import (
+ RMSNormExpertProj,
+)
+
+
+class Moe(SerializableTileRTModule):
+ """Implement the MOE operations."""
+
+ rmsnorm_expert_proj: RMSNormExpertProj
+
+ def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.rmsnorm_expert_proj = RMSNormExpertProj(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.rmsnorm_expert_proj)
+
+ self.exp_sel_up_gate_silu = ExpertSelectUpGateSiLU(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ algorithm=ExpertSelectUpGateSiLUAlgorithm.FP16MMA,
+ )
+ self.register_op(self.exp_sel_up_gate_silu)
+
+ self.expert_down_allreduce = ExpertDownAllReduce(
+ model_args=model_args, device_id=device_id, num_devices=num_devices
+ )
+ self.register_op(self.expert_down_allreduce)
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return super().get_weights_list()
+
+
+class MoeBlock(SerializableTileRTModule):
+ """Implement the MOE block operations."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ remove_selected: bool = False,
+ mla_cls: type | None = None,
+ mla_num_devices: int | None = None,
+ mla_kwargs: dict | None = None,
+ moe: "Moe | None" = None,
+ ):
+ super().__init__(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ remove_selected=remove_selected,
+ )
+
+ mla_class = mla_cls or Mla
+ mla_nd = mla_num_devices if mla_num_devices is not None else num_devices
+ self.mla = mla_class(
+ model_args=model_args, device_id=device_id, num_devices=mla_nd, **(mla_kwargs or {})
+ )
+ self.register_op(self.mla)
+ self.moe = (
+ moe
+ if moe is not None
+ else Moe(model_args=model_args, device_id=device_id, num_devices=num_devices)
+ )
+ self.register_op(self.moe)
diff --git a/tilert/models/glm_5/_dsa_v32/modules/mtp.py b/tilert/models/glm_5/_dsa_v32/modules/mtp.py
new file mode 100644
index 0000000..ccfbdc8
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/mtp.py
@@ -0,0 +1,62 @@
+import torch
+
+from tilert.models.base import SerializableTileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.moe import MoeBlock
+from tilert.models.glm_5._dsa_v32.modules.mtp_preprocess import MTPPreprocessLayer
+from tilert.models.glm_5._dsa_v32.ops import RMSNormHeadProj
+
+
+class MTP(SerializableTileRTModule):
+ """MTP module."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ mla_cls: type | None = None,
+ mla_num_devices: int | None = None,
+ mla_kwargs: dict | None = None,
+ ):
+ super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices)
+
+ self.embed_tokens_weight = None
+ self.freqs_cis = None
+
+ mtp_layer_id = self.model_args.n_layers
+ self.register_op(
+ MTPPreprocessLayer(self.model_args, self.num_devices, device_id),
+ prefix=f"layer_{mtp_layer_id}_",
+ suffix=f"_dev_{device_id}",
+ )
+ self.register_op(
+ MoeBlock(
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ mla_cls=mla_cls,
+ mla_num_devices=mla_num_devices,
+ mla_kwargs=mla_kwargs,
+ ),
+ prefix=f"layer_{mtp_layer_id}_",
+ suffix=f"_dev_{device_id}",
+ )
+ self.register_op(
+ RMSNormHeadProj(model_args=model_args, device_id=device_id, num_devices=num_devices),
+ prefix=f"layer_{mtp_layer_id}_",
+ suffix=f"_dev_{device_id}",
+ retain_weights=True,
+ )
+
+ def init_tilert_weights(self, state_dicts: dict[str, torch.Tensor]) -> None:
+ self.embed_tokens_weight = state_dicts["model.embed_tokens.weight"]
+ self.freqs_cis = state_dicts["freqs_cis"]
+ super().init_tilert_weights(state_dicts)
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [
+ self.embed_tokens_weight,
+ self.freqs_cis,
+ *super().get_weights_list(),
+ ]
diff --git a/tilert/models/glm_5/_dsa_v32/modules/mtp_preprocess.py b/tilert/models/glm_5/_dsa_v32/modules/mtp_preprocess.py
new file mode 100644
index 0000000..debd75d
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/modules/mtp_preprocess.py
@@ -0,0 +1,238 @@
+"""MTP preprocess layer for DeepSeek v3."""
+
+from dataclasses import dataclass
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import init_func, linear
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+
+__all__ = [
+ "mtp_preprocess_layer",
+ "MTPPreprocessLayer",
+ "MTPPreprocessRefWeightsAlias",
+ "MTPPreprocessTilertWeightsAlias",
+ "MTPPreprocessWeightsConverter",
+]
+
+
+def mtp_preprocess_layer(
+ params: list[torch.Tensor],
+ temp_vars: list[torch.Tensor],
+ profile_logs: torch.Tensor,
+) -> torch.Tensor:
+ """MTP preprocess layer op for DeepSeek v3."""
+ return torch.ops.tilert.mtp_preprocess_layer(params, temp_vars, profile_logs)
+
+
+@dataclass
+class MTPPreprocessRefWeightsAlias:
+ """Reference (golden/PyTorch) weight keys for MTP preprocess."""
+
+ embedding_rmsnorm = "enorm.weight"
+ hidden_rmsnorm = "hnorm.weight"
+ eh_proj = "eh_proj.weight"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.embedding_rmsnorm,
+ self.hidden_rmsnorm,
+ self.eh_proj,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class MTPPreprocessTilertWeightsAlias:
+ """TileRT weight keys for MTP preprocess."""
+
+ embedding_rmsnorm_gamma = "embedding_rmsnorm_gamma"
+ hidden_rmsnorm_gamma = "hidden_rmsnorm_gamma"
+ eh_proj_weights = "eh_proj_weights"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.embedding_rmsnorm_gamma,
+ self.hidden_rmsnorm_gamma,
+ self.eh_proj_weights,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class MTPPreprocessWeightsConverter(TilertWeightsConverter):
+ """Converts ref-format weights to TileRT format for MTP preprocess."""
+
+ def convert_to_tilert(self, weights: list[torch.Tensor], device_id: int) -> list[torch.Tensor]:
+ """
+ Convert ref weights to TileRT format for a specific device.
+
+ Args:
+ weights: [embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj.weight]
+ Ref format: enorm.weight [7168], hnorm.weight [7168],
+ eh_proj.weight [7168, 14336].
+ device_id: Target device ID for weight placement.
+
+ Returns:
+ MTPPreprocessParams with converted weights for device_id.
+ """
+ device = torch.device(f"cuda:{device_id}")
+ embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj_weight = weights
+
+ embedding_rmsnorm_gamma = embedding_rmsnorm_gamma.to(device=device, dtype=torch.float32)
+ hidden_rmsnorm_gamma = hidden_rmsnorm_gamma.to(device=device, dtype=torch.float32)
+ eh_proj_weights = (
+ eh_proj_weight.reshape(
+ 128, self.model_args.dim // 128, self.model_args.dim * 2 // 256 // 8, 256
+ )
+ .transpose(1, 2)
+ .contiguous()
+ .to(device=device, dtype=torch.bfloat16)
+ )
+ return [embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj_weights]
+
+
+class MTPPreprocessLayer(TileRTModule):
+ """MTP preprocess layer: RMSNorm(embedding), RMSNorm(hidden), concat & project."""
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: MTPPreprocessRefWeightsAlias | None = None,
+ ) -> None:
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+ self.tilert_weights_alias = MTPPreprocessTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else MTPPreprocessRefWeightsAlias()
+ )
+ self.hidden_size = model_args.dim
+
+ self.tilert_embedding_rmsnorm_gamma: torch.Tensor | None = None
+ self.tilert_hidden_rmsnorm_gamma: torch.Tensor | None = None
+ self.tilert_eh_proj_weights: torch.Tensor | None = None
+
+ self.ref_embedding_rmsnorm_gamma: torch.Tensor | None = None
+ self.ref_hidden_rmsnorm_gamma: torch.Tensor | None = None
+ self.ref_eh_proj_weight: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [
+ self.tilert_embedding_rmsnorm_gamma,
+ self.tilert_hidden_rmsnorm_gamma,
+ self.tilert_eh_proj_weights,
+ ]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat ref weights for each device (for init_tilert_weights from ref)."""
+ embedding_gamma = weights_map[self.ref_weights_alias.embedding_rmsnorm]
+ hidden_gamma = weights_map[self.ref_weights_alias.hidden_rmsnorm]
+ eh_proj_weights = weights_map[self.ref_weights_alias.eh_proj]
+ return {
+ self.tilert_weights_alias.embedding_rmsnorm_gamma: (
+ embedding_gamma[None, ...].repeat(self.num_devices, 1)
+ ),
+ self.tilert_weights_alias.hidden_rmsnorm_gamma: (
+ hidden_gamma[None, ...].repeat(self.num_devices, 1)
+ ),
+ self.tilert_weights_alias.eh_proj_weights: (
+ eh_proj_weights[None, ...]
+ .reshape(
+ self.model_args.dim,
+ self.num_devices,
+ self.model_args.dim * 2 // self.num_devices,
+ )
+ .transpose(0, 1)
+ ),
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Load ref-format weights (enorm.weight, hnorm.weight, eh_proj.weight)."""
+ self.ref_embedding_rmsnorm_gamma = state_dict[self.ref_weights_alias.embedding_rmsnorm]
+ self.ref_hidden_rmsnorm_gamma = state_dict[self.ref_weights_alias.hidden_rmsnorm]
+ self.ref_eh_proj_weight = state_dict[self.ref_weights_alias.eh_proj]
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Load TileRT weights from state_dict.
+
+ state_dict may use:
+ - Full keys: layer_{layer_id}_{alias}_dev_{device_id}
+ - Short keys: embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj_weights
+ - Ref keys: enorm.weight, hnorm.weight, eh_proj.weight (then convert)
+ """
+ converter = MTPPreprocessWeightsConverter(self.model_args, self.num_devices)
+ params = converter.convert_to_tilert(
+ [state_dict[k] for k in self.tilert_weights_alias()], self.device_id
+ )
+ self.tilert_embedding_rmsnorm_gamma = params[0]
+ self.tilert_hidden_rmsnorm_gamma = params[1]
+ self.tilert_eh_proj_weights = params[2]
+
+ def init_random_weights(self) -> dict[str, torch.Tensor]:
+ """Initialize random ref weights and convert to TileRT for this device."""
+ embedding_gamma = init_func(torch.randn(self.hidden_size, dtype=torch.float32))
+ hidden_gamma = init_func(torch.randn(self.hidden_size, dtype=torch.float32))
+ eh_proj_weights = init_func(
+ torch.randn(self.hidden_size, self.hidden_size * 2, dtype=torch.bfloat16)
+ )
+ return {
+ self.ref_weights_alias.embedding_rmsnorm: embedding_gamma,
+ self.ref_weights_alias.hidden_rmsnorm: hidden_gamma,
+ self.ref_weights_alias.eh_proj: eh_proj_weights,
+ }
+
+ def golden_forward(
+ self,
+ x: torch.Tensor,
+ last_hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Reference forward: enorm(x), hnorm(last_hidden), concat & eh_proj.
+
+ Args:
+ x: [batch, seq_len, hidden_size] embedded tokens
+ last_hidden_states: [batch, seq_len, hidden_size] previous hidden
+
+ Returns:
+ [batch, seq_len, hidden_size] projected hidden
+ """
+ assert self.ref_embedding_rmsnorm_gamma is not None
+ assert self.ref_hidden_rmsnorm_gamma is not None
+ assert self.ref_eh_proj_weight is not None
+
+ future_norm = torch.nn.functional.rms_norm(
+ x.float(),
+ [x.size(-1)],
+ self.ref_embedding_rmsnorm_gamma,
+ 1e-6,
+ )
+ prev_norm = torch.nn.functional.rms_norm(
+ last_hidden_states.float(),
+ [last_hidden_states.size(-1)],
+ self.ref_hidden_rmsnorm_gamma,
+ 1e-6,
+ )
+ combined = torch.cat([future_norm, prev_norm], dim=-1)
+ return linear(combined, self.ref_eh_proj_weight)
+
+ def tilert_forward(
+ self,
+ params: list[torch.Tensor],
+ temp_vars: list[torch.Tensor],
+ profile_logs: torch.Tensor,
+ ) -> torch.Tensor:
+ """Run TileRT mtp_preprocess_layer op."""
+ return mtp_preprocess_layer(params, temp_vars, profile_logs)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/__init__.py b/tilert/models/glm_5/_dsa_v32/ops/__init__.py
new file mode 100644
index 0000000..a58dab8
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/__init__.py
@@ -0,0 +1,160 @@
+"""Core operations for deepseek v3.2."""
+
+from tilert.models.glm_5._dsa_v32.ops.broadcast_selected_token_ids import (
+ broadcast_selected_token_ids,
+)
+from tilert.models.glm_5._dsa_v32.ops.down_allreduce import (
+ DownAllReduce,
+ DownAllReduceAlgorithm,
+ down_allreduce,
+)
+from tilert.models.glm_5._dsa_v32.ops.eh_proj_allreduce import (
+ EHProjAllReduce,
+ EHProjAllReduceAlgorithm,
+ eh_proj_allreduce,
+)
+from tilert.models.glm_5._dsa_v32.ops.expert_down_allreduce import (
+ ExpertDownAllReduce,
+ ExpertDownAllReduceAlgorithm,
+ expert_down_allreduce,
+)
+from tilert.models.glm_5._dsa_v32.ops.expert_sel_up_gate_silu import (
+ ExpertSelectUpGateSiLU,
+ ExpertSelectUpGateSiLUAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.flash_sparse_mla import (
+ FlashSparseMLACombineAlgorithm,
+ flash_sparse_mla,
+)
+from tilert.models.glm_5._dsa_v32.ops.layernorm_rope_rotate import (
+ LayerNormRoPERotateAlgorithm,
+ layernorm_rope_rotate,
+)
+from tilert.models.glm_5._dsa_v32.ops.padded_allreduce_add import (
+ PaddedAllReduceAdd,
+ PaddedAllReduceAddAlgorithm,
+ padded_allreduce_add,
+)
+from tilert.models.glm_5._dsa_v32.ops.projo_wkvb import ProjoWKVbAlgorithm, projo_wkvb
+from tilert.models.glm_5._dsa_v32.ops.projq_wqb import ProjqWqbAlgorithm, projq_wqb
+from tilert.models.glm_5._dsa_v32.ops.projx_wis import ProjxWisAlgorithm, projx_wis
+from tilert.models.glm_5._dsa_v32.ops.qkv_rope import (
+ QKVRoPE,
+ QKVRoPEAlgorithm,
+ QKVRoPERefWeightsAlias,
+ QKVRoPETilertWeightsAlias,
+ qkv_rope,
+)
+from tilert.models.glm_5._dsa_v32.ops.receive_selected_token_ids import (
+ receive_selected_token_ids,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_expert_proj import (
+ RMSNormExpertProj,
+ RMSNormExpertProjAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_head_proj import (
+ RMSNormHeadProj,
+ RMSNormHeadProjAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_kv import KVRMSNormAlgorithm, rmsnorm_kv
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqb,
+ RmsnormProjqWqbAlgorithm,
+ RmsnormProjqWqbWeightsConverter,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqi import (
+ RmsnormProjqWqi,
+ RmsnormProjqWqiAlgorithm,
+ RmsnormProjqWqiWeightsConverter,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projx_wqakis import (
+ RMSNormProjxWqakis,
+ RMSNormProjxWqakisAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projx_wqkva import (
+ RMSNormProjxWqkva,
+ RMSNormProjxWqkvaAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_quant import rmsnorm_quant
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_up_gate_silu import (
+ RMSNormUpGateSiLU,
+ RMSNormUpGateSiLUAlgorithm,
+)
+from tilert.models.glm_5._dsa_v32.ops.rotate import (
+ Rotate,
+ RotateAlgorithm,
+ RotateRefWeightsAlias,
+ RotateTilertWeightsAlias,
+ rotate,
+ rotate_activation,
+)
+from tilert.models.glm_5._dsa_v32.ops.sparse_index import sparse_index, sparse_index_topk
+from tilert.models.glm_5._dsa_v32.ops.topk import TopK, topk_accurate, topk_approximate
+from tilert.models.glm_5._dsa_v32.ops.unproj_o_allreduce import (
+ UnProjOAllReduce,
+ UnProjOAllReduceAlgorithm,
+ unproj_o_allreduce,
+)
+
+__all__ = [
+ "down_allreduce",
+ "DownAllReduce",
+ "DownAllReduceAlgorithm",
+ "expert_down_allreduce",
+ "ExpertDownAllReduce",
+ "ExpertDownAllReduceAlgorithm",
+ "rmsnorm_kv",
+ "KVRMSNormAlgorithm",
+ "unproj_o_allreduce",
+ "projo_wkvb",
+ "ProjoWKVbAlgorithm",
+ "projq_wqb",
+ "ProjqWqbAlgorithm",
+ "rotate",
+ "rotate_activation",
+ "Rotate",
+ "RotateAlgorithm",
+ "RotateRefWeightsAlias",
+ "RotateTilertWeightsAlias",
+ "layernorm_rope_rotate",
+ "LayerNormRoPERotateAlgorithm",
+ "TopK",
+ "topk_approximate",
+ "topk_accurate",
+ "sparse_index",
+ "sparse_index_topk",
+ "flash_sparse_mla",
+ "FlashSparseMLACombineAlgorithm",
+ "projx_wis",
+ "ProjxWisAlgorithm",
+ "qkv_rope",
+ "QKVRoPE",
+ "QKVRoPEAlgorithm",
+ "QKVRoPERefWeightsAlias",
+ "QKVRoPETilertWeightsAlias",
+ "eh_proj_allreduce",
+ "EHProjAllReduceAlgorithm",
+ "rmsnorm_quant",
+ "RmsnormProjqWqi",
+ "RmsnormProjqWqiAlgorithm",
+ "RmsnormProjqWqiWeightsConverter",
+ "RMSNormExpertProj",
+ "RMSNormExpertProjAlgorithm",
+ "RMSNormProjxWqakis",
+ "RMSNormProjxWqakisAlgorithm",
+ "RMSNormProjxWqkva",
+ "RMSNormProjxWqkvaAlgorithm",
+ "RMSNormUpGateSiLU",
+ "RMSNormUpGateSiLUAlgorithm",
+ "UnProjOAllReduce",
+ "UnProjOAllReduceAlgorithm",
+ "RMSNormHeadProj",
+ "RMSNormHeadProjAlgorithm",
+ "ExpertSelectUpGateSiLU",
+ "ExpertSelectUpGateSiLUAlgorithm",
+ "PaddedAllReduceAdd",
+ "PaddedAllReduceAddAlgorithm",
+ "padded_allreduce_add",
+ "broadcast_selected_token_ids",
+ "receive_selected_token_ids",
+]
diff --git a/tilert/models/glm_5/_dsa_v32/ops/broadcast_selected_token_ids.py b/tilert/models/glm_5/_dsa_v32/ops/broadcast_selected_token_ids.py
new file mode 100644
index 0000000..f6bf2a8
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/broadcast_selected_token_ids.py
@@ -0,0 +1,36 @@
+"""BroadcastSelectedTokenIds — P2P broadcast of idx_selects from GPU 0 to peers."""
+
+import torch
+
+__all__ = [
+ "broadcast_selected_token_ids",
+]
+
+
+def broadcast_selected_token_ids(
+ idx_selects: torch.Tensor,
+ peer_bufs: torch.Tensor,
+ flag_val: int,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """Broadcast idx_selects [1,S,2048] int32 from GPU 0 to peer GPUs.
+
+ Args:
+ idx_selects: Source tensor [1, S, 2048] int32 on GPU 0.
+ peer_bufs: Device pointer array [N] int64 — each entry is a peer
+ buffer address.
+ flag_val: Synchronization flag value.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.broadcast_selected_token_ids_op(
+ idx_selects,
+ peer_bufs,
+ flag_val,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
diff --git a/tilert/models/glm_5/_dsa_v32/ops/down_allreduce.py b/tilert/models/glm_5/_dsa_v32/ops/down_allreduce.py
new file mode 100644
index 0000000..38b305c
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/down_allreduce.py
@@ -0,0 +1,343 @@
+"""DownAllreduce operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.ops.expert_down_allreduce import (
+ ExpertDownAllReduceWeightsConverter,
+)
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "down_allreduce",
+ "DownAllReduceAlgorithm",
+ "DownAllReduce",
+ "DownAllReduceTilertWeightsAlias",
+]
+
+
+def down_allreduce(
+ vec_in: torch.Tensor,
+ mat_in: torch.Tensor,
+ mat_scale: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ vec_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """
+ Fused operation of down and allreduce.
+
+ Args:
+ vec_in: Input tensor.
+ mat_in: Input tensor.
+ mat_scale: Input tensor.
+ x_in: Input tensor.
+ flag: Input flag.
+ vec_out: Output tensor.
+ profile_logs: Profile logs tensor (1D).
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.down_allreduce_op(
+ vec_in,
+ mat_in,
+ mat_scale,
+ x_in,
+ flag,
+ vec_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+
+
+class DownAllReduceAlgorithm(Enum):
+ """DownAllReduce algorithm"""
+
+ GENERAL = "general"
+
+
+DownAllReduceWeightsConverter = ExpertDownAllReduceWeightsConverter
+
+
+@dataclass
+class DownAllReduceTilertWeightsAlias:
+ """TileRT weights alias for DownAllReduce."""
+
+ down_weights = "down_weights"
+ down_scales = "down_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.down_weights, self.down_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class DownAllReduce(TileRTModule):
+ """DownAllReduce module"""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [DownAllReduceAlgorithm.GENERAL],
+ "glm_5": [DownAllReduceAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ algorithm: DownAllReduceAlgorithm = DownAllReduceAlgorithm.GENERAL,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ device_id=device_id,
+ model_args=model_args,
+ num_devices=num_devices,
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+
+ self.inter_dim = self.model_args.inter_dim
+ self.moe_inter_dim = self.model_args.moe_inter_dim
+ self.moe_inter_dim_per_device = self.moe_inter_dim // self.num_devices
+ self.inter_dim_per_device = self.inter_dim // self.num_devices
+ self.n_experts: int = self.inter_dim_per_device // self.moe_inter_dim_per_device
+ self.block_size = self.model_args.block_size
+ self.dim_scale_dim = self.dim // self.block_size
+ self.in_scale_dim = self.inter_dim // self.block_size
+ self.moe_inter_scale_dim_per_device = self.moe_inter_dim_per_device // self.block_size
+ self.algorithm = algorithm
+
+ if self.arch_name in ("deepseek_v3_2", "glm_5"):
+ self.compute_kernel_type = "bf16"
+ else:
+ raise ValueError(f"Unsupported architecture: {self.arch_name}")
+
+ self.model_arch = self.arch_name
+
+ self.ref_down: torch.Tensor | None = None
+
+ self.tilert_weights: torch.Tensor | None = None
+ self.tilert_scales: torch.Tensor | None = None
+
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.tilert_weights_alias = DownAllReduceTilertWeightsAlias()
+
+ self.tensor_alias: list[str] = [
+ "down_weights",
+ "down_scales",
+ ]
+
+ self.ref_tensor_alias: list[str] = [
+ "mlp.down_proj.weight",
+ "mlp.down_proj.weight_scale_inv",
+ ]
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias.tilert_tensor_alias
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """
+ Get the weights list.
+
+ Returns:
+ List of weights.
+ """
+ return [self.tilert_weights, self.tilert_scales]
+
+ def device_sharding(
+ self,
+ weights_dict: dict[str, torch.Tensor],
+ key_prefix: str,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Device sharding.
+
+ Args:
+ weights_dict: Dictionary of weights.
+ key_prefix: Key prefix.
+ Returns:
+ Tuple of weights.
+ """
+ down_proj_weight_key = f"{key_prefix}.down_proj.weight"
+ down_proj_scale_key = f"{key_prefix}.down_proj.weight_scale_inv"
+ down_proj_weight = weights_dict[down_proj_weight_key]
+ down_proj_scale = weights_dict[down_proj_scale_key]
+ down_proj_weight = down_proj_weight.reshape(
+ self.dim, self.n_experts, self.num_devices, self.moe_inter_dim_per_device
+ )
+ down_proj_weight_splited = torch.split(down_proj_weight, 1, dim=2)
+
+ down_proj_weight_splited = [
+ down_proj_weight_splited[i]
+ .reshape(self.dim, self.n_experts, self.moe_inter_dim_per_device)
+ .transpose(0, 1)
+ .contiguous()
+ for i in range(self.num_devices)
+ ]
+
+ down_proj_scale = down_proj_scale.reshape(
+ self.dim_scale_dim,
+ self.n_experts,
+ self.num_devices,
+ self.moe_inter_scale_dim_per_device,
+ )
+ down_proj_scale_splited = torch.split(down_proj_scale, 1, dim=2)
+ down_proj_scale_splited = [
+ down_proj_scale_splited[i]
+ .reshape(self.dim_scale_dim, self.n_experts, self.moe_inter_scale_dim_per_device)
+ .transpose(0, 1)
+ .contiguous()
+ for i in range(self.num_devices)
+ ]
+ down_weights = torch.stack(down_proj_weight_splited, dim=0)
+ down_scales = torch.stack(down_proj_scale_splited, dim=0)
+ return down_weights.contiguous(), down_scales.contiguous()
+
+ def init_reference_weights(
+ self,
+ state_dict: dict[str, torch.Tensor],
+ key_prefix: str,
+ device_id: int = 0,
+ ) -> None:
+ """
+ Initialize the reference weights.
+
+ Args:
+ state_dict: State dictionary.
+ device_id: Device ID.
+ """
+ sharded_list = self.device_sharding(state_dict, key_prefix)
+
+ down_weights = sharded_list[0][device_id]
+ down_scales = sharded_list[1][device_id]
+
+ down_list = [
+ weight_dequant(down_weight, down_scale)
+ for down_weight, down_scale in zip(down_weights, down_scales)
+ ]
+ self.ref_down = torch.stack(down_list, dim=0)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Initialize the tilert weights.
+
+ Args:
+ state_dict: State dictionary.
+ """
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_weights, self.tilert_scales = DownAllReduceWeightsConverter(
+ self.model_args, self.num_devices
+ ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tensor_alias])
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) -> None:
+ """
+ Initialize the tilert variables.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}")
+ self.is_init = True
+
+ def init_random_weights(self, device_id: int = 0) -> None:
+ """Initialize the random weights."""
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+ down_weights = torch.randn(
+ self.dim, self.inter_dim, dtype=torch.bfloat16, device=f"cuda:{device_id}"
+ ).to(torch.float8_e4m3fn)
+
+ inter_dim_scale_dim = self.inter_dim // self.block_size
+ dim_scale_dim = self.dim // self.block_size
+ down_scales = torch.randn(
+ dim_scale_dim, inter_dim_scale_dim, dtype=scale_dtype, device=f"cuda:{device_id}"
+ )
+ tensor_list = [
+ down_weights,
+ down_scales,
+ ]
+ state_dict = dict(zip(self.ref_tensor_alias, tensor_list))
+
+ self.init_reference_weights(state_dict, "mlp", device_id)
+ sharded_list = self.device_sharding(state_dict, "mlp")
+
+ sharded_state_dict = {
+ alias: sharded_list[i][device_id] for i, alias in enumerate(self.tensor_alias)
+ }
+ self.init_tilert_weights(sharded_state_dict)
+
+ def golden_forward(
+ self,
+ vec_in: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Forward pass for the down-project module.
+
+ Args:
+ vec_in: Input vector.
+
+ Returns:
+ Output tensor.
+ """
+ assert self.ref_down is not None
+ bsz = vec_in.shape[0]
+ assert bsz == 1
+ seq_len = vec_in.shape[1]
+ hidden_out_list = []
+ for s in range(seq_len):
+ hidden_out_w2_list = []
+ for i in range(self.n_experts):
+ hidden_out_w2_sel = vec_in[0, s, i].float() @ self.ref_down[i].float().T
+ hidden_out_w2_list.append(hidden_out_w2_sel)
+ hidden_out_w2 = torch.stack(hidden_out_w2_list, dim=0).to(torch.bfloat16)
+ hidden_out_w2 = torch.sum(hidden_out_w2, dim=0)
+ hidden_out_list.append(hidden_out_w2)
+ return torch.stack(hidden_out_list, dim=0)[None, ...]
+
+ def tilert_forward(
+ self,
+ vec_in: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ ) -> torch.Tensor:
+ assert self.hidden_out is not None
+ down_allreduce(
+ vec_in,
+ self.tilert_weights,
+ self.tilert_scales,
+ x_in,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ self.model_arch,
+ self.compute_kernel_type,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(x_in)
diff --git a/python/models/deepseek_v3_2/ops/eh_proj_allreduce.py b/tilert/models/glm_5/_dsa_v32/ops/eh_proj_allreduce.py
similarity index 87%
rename from python/models/deepseek_v3_2/ops/eh_proj_allreduce.py
rename to tilert/models/glm_5/_dsa_v32/ops/eh_proj_allreduce.py
index 309751a..fe0b71f 100644
--- a/python/models/deepseek_v3_2/ops/eh_proj_allreduce.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/eh_proj_allreduce.py
@@ -3,11 +3,10 @@
from dataclasses import dataclass
from enum import Enum
-# import torch.nn.functional as F
import torch
from tilert.models.base import TileRTModule, TilertWeightsConverter
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -23,6 +22,7 @@ def eh_proj_allreduce(
flag: int,
vec_out: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
) -> None:
"""
Fused operation of EHProj and allreduce.
@@ -33,19 +33,21 @@ def eh_proj_allreduce(
w_eh: Input tensor of shape (7168, 1792) or (128, 7, 56, 256).
flag: Input tensor.
vec_out: Output tensor of shape (1, seq_len, 7168).
- profile_logs: Profile logs tensor. This is a 1D tensor of shape
- (num_sms,) to store the profile logs of the eh_proj_allreduce
- operation, where num_sms is the number of SMs on the
- device.
+ profile_logs: Profile logs tensor (1D).
+ model_arch: Model architecture string.
"""
- dim = vec_in_enorm.shape[-1]
- if dim == 7168:
- func_call = torch.ops.tilert.eh_proj_allreduce_op
- elif dim == 6144:
- func_call = torch.ops.tilert.eh_proj_allreduce_glm5_op
- else:
- raise ValueError(f"Unsupported dimension: {dim}")
- func_call(vec_in_enorm, vec_in_hnorm, w_eh, flag, vec_out, profile_logs)
+ compute_kernel_type = "bf16"
+ torch.ops.tilert.eh_proj_allreduce_op(
+ vec_in_enorm,
+ vec_in_hnorm,
+ w_eh,
+ flag,
+ vec_out,
+ profile_logs,
+ model_arch,
+ compute_kernel_type,
+ torch.empty(0, dtype=torch.int64, device=vec_in_enorm.device),
+ )
class EHProjAllReduceAlgorithm(Enum):
@@ -100,6 +102,11 @@ def __call__(self) -> list[str]:
class EHProjAllReduce(TileRTModule):
"""EHProjAllReduce module"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [EHProjAllReduceAlgorithm.GENERAL],
+ "glm_5": [EHProjAllReduceAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -117,13 +124,10 @@ def __init__(
self.algorithm = algorithm
- # reference weights
self.ref_proj: torch.Tensor | None = None
- # tilert weights
self.tilert_proj: torch.Tensor | None = None
- # tilert vars
self.hidden_out: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
@@ -131,13 +135,10 @@ def __init__(
self.tilert_weights_alias = EHProjAllReduceTilertWeightsAlias()
- # for device sharding, corresponding to the output of device_sharding
- # and input of tilert_forward
self.tensor_alias: list[str] = [
"eh_proj_weights",
]
- # reference tensor aliases
self.ref_tensor_alias: list[str] = [
"eh_proj.weight",
]
@@ -158,7 +159,7 @@ def get_weights_list(self) -> list[torch.Tensor]:
def device_sharding(
self,
weights_dict: dict[str, torch.Tensor],
- key_prefix: str | None = None, # e.g. model.layers.{layer_id}
+ key_prefix: str | None = None,
) -> tuple[torch.Tensor]:
"""
Device sharding.
@@ -220,7 +221,6 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) ->
batch_size: Batch size.
seq_len: Sequence length.
"""
- # tilert vars
self.hidden_out = torch.zeros(
(batch_size, seq_len, self.dim),
dtype=torch.bfloat16,
@@ -282,6 +282,12 @@ def tilert_forward(
) -> torch.Tensor:
assert self.hidden_out is not None
eh_proj_allreduce(
- vec_in_enorm, vec_in_hnorm, self.tilert_proj, flag, self.hidden_out, self.profile_logs
+ vec_in_enorm,
+ vec_in_hnorm,
+ self.tilert_proj,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
)
return self.hidden_out
diff --git a/python/models/deepseek_v3_2/ops/expert_down_allreduce.py b/tilert/models/glm_5/_dsa_v32/ops/expert_down_allreduce.py
similarity index 78%
rename from python/models/deepseek_v3_2/ops/expert_down_allreduce.py
rename to tilert/models/glm_5/_dsa_v32/ops/expert_down_allreduce.py
index d49bc77..b0e6b24 100644
--- a/python/models/deepseek_v3_2/ops/expert_down_allreduce.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/expert_down_allreduce.py
@@ -1,4 +1,5 @@
-from collections.abc import Callable
+"""ExpertDownAllreduce operation module."""
+
from dataclasses import dataclass
from enum import Enum
@@ -6,12 +7,11 @@
from tilert.models.base import TileRTModule, TilertWeightsConverter
from tilert.models.common import weight_dequant
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
from tilert.utils import get_profile_log_tensor
__all__ = [
"expert_down_allreduce",
- "expert_down_allreduce_glm5",
"ExpertDownAllReduceAlgorithm",
"ExpertDownAllReduce",
"ExpertDownAllReduceTilertWeightsAlias",
@@ -31,53 +31,36 @@ def expert_down_allreduce(
flag: int,
vec_out: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
) -> None:
"""
- Fused expert down + allreduce (deepseek_v3_2).
+ Fused expert down + allreduce (unified for DSv32 and GLM5).
Args:
vec_in: [1, seq_len, n_experts, 256], bfloat16.
- mat_in: [n_experts, 6144, 256], float8_e4m3fn.
- mat_scale: [n_experts, 1024, 2], bfloat16.
+ mat_in: [n_experts, dim, 256], float8_e4m3fn.
+ mat_scale: [n_experts, 1024, 2], bfloat16 (DSv32) or float32 (GLM5).
indices: [1, seq_len, 8], int32.
scores: [1, seq_len, 8], float32.
- x_in: [1, seq_len, 6144], bfloat16.
+ x_in: [1, seq_len, dim], bfloat16.
flag: User flag.
- vec_out: [1, seq_len, 6144], bfloat16 (output).
- profile_logs: 1D tensor (num_sms,) for profile logs.
+ vec_out: [1, seq_len, dim], bfloat16 (output).
+ profile_logs: 1D tensor for profile logs.
+ compute_kernel_type: "bf16".
"""
torch.ops.tilert.expert_down_allreduce_op(
- vec_in, mat_in, mat_scale, indices, scores, x_in, flag, vec_out, profile_logs
- )
-
-
-def expert_down_allreduce_glm5(
- vec_in: torch.Tensor,
- mat_in: torch.Tensor,
- mat_scale: torch.Tensor,
- indices: torch.Tensor,
- scores: torch.Tensor,
- x_in: torch.Tensor,
- flag: int,
- vec_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """
- Fused expert down + allreduce (glm_5).
-
- Args:
- vec_in: [1, seq_len, n_experts, 256], bfloat16.
- mat_in: [n_experts, 6144, 256], float8_e4m3fn.
- mat_scale: [n_experts, 1024, 2], bfloat16.
- indices: [1, seq_len, 8], int32.
- scores: [1, seq_len, 8], float32.
- x_in: [1, seq_len, 6144], bfloat16.
- flag: User flag.
- vec_out: [1, seq_len, 6144], bfloat16 (output).
- profile_logs: 1D tensor (num_sms,) for profile logs.
- """
- torch.ops.tilert.expert_down_allreduce_glm5_op(
- vec_in, mat_in, mat_scale, indices, scores, x_in, flag, vec_out, profile_logs
+ vec_in,
+ mat_in,
+ mat_scale,
+ indices,
+ scores,
+ x_in,
+ flag,
+ vec_out,
+ profile_logs,
+ model_arch,
+ compute_kernel_type,
)
@@ -115,36 +98,51 @@ def convert_to_general(
num_sms = 128
dim_per_sm = dim // num_sms
dim_scale_dim = dim // args.block_size
+ expert_dim = args.moe_inter_dim // 8
+ k_chunks = expert_dim // 32
+ scale_cols = expert_dim // args.block_size
with torch.inference_mode():
mat_in, scale_in = weights_list
exp_num = mat_in.shape[0]
- mat_in_s = mat_in.reshape(exp_num, num_sms, dim_per_sm, 256)
- mat_in_0 = mat_in_s[:, :, :16].reshape(exp_num, num_sms, 16, 8, 32).transpose(2, 3)
+ mat_in_s = mat_in.reshape(exp_num, num_sms, dim_per_sm, expert_dim)
+ mat_in_0 = (
+ mat_in_s[:, :, :16].reshape(exp_num, num_sms, 16, k_chunks, 32).transpose(2, 3)
+ )
mat_in_0 = self._swizzle_qmma_16x32(mat_in_0).reshape(exp_num, 128, -1)
- mat_in_1 = mat_in_s[:, :, 16:32].reshape(exp_num, num_sms, 16, 8, 32).transpose(2, 3)
+ mat_in_1 = (
+ mat_in_s[:, :, 16:32].reshape(exp_num, num_sms, 16, k_chunks, 32).transpose(2, 3)
+ )
mat_in_1 = self._swizzle_qmma_16x32(mat_in_1).reshape(exp_num, 128, -1)
- mat_in_2 = mat_in_s[:, :, 32:48].reshape(exp_num, num_sms, 16, 8, 32).transpose(2, 3)
+ mat_in_2 = (
+ mat_in_s[:, :, 32:48].reshape(exp_num, num_sms, 16, k_chunks, 32).transpose(2, 3)
+ )
mat_in_2 = self._swizzle_qmma_16x32(mat_in_2).reshape(exp_num, 128, -1)
- mat_in_swizzled = torch.cat([mat_in_0, mat_in_1, mat_in_2], dim=2)
+ mats_to_cat = [mat_in_0, mat_in_1, mat_in_2]
if arch_name == "deepseek_v3_2":
- mat_in_3 = mat_in_s[:, :, 48:56].reshape(exp_num, num_sms, 8, 8, 32).transpose(2, 3)
+ mat_in_3 = (
+ mat_in_s[:, :, 48:56].reshape(exp_num, num_sms, 8, k_chunks, 32).transpose(2, 3)
+ )
mat_in_3 = self._swizzle_qmma_8x32(mat_in_3).reshape(exp_num, 128, -1)
- mat_in_swizzled = torch.cat([mat_in_0, mat_in_1, mat_in_2, mat_in_3], dim=2)
- mat_in_swizzled = mat_in_swizzled.reshape(exp_num, dim, 256)
+ mats_to_cat.append(mat_in_3)
+ mat_in_swizzled = torch.cat(mats_to_cat, dim=2)
+ mat_in_swizzled = mat_in_swizzled.reshape(exp_num, dim, expert_dim)
mat_scale_tilert = (
- scale_in.reshape(exp_num, dim_scale_dim, 1, 2)
+ scale_in.reshape(exp_num, dim_scale_dim, 1, scale_cols)
.repeat(1, 1, 16, 1)
.reshape(exp_num, num_sms, -1)
)
- padding_zeros = torch.zeros(
- (exp_num, num_sms, 16 - mat_scale_tilert.shape[-1]),
- dtype=scale_in.dtype,
- device=scale_in.device,
- )
- mat_scale_tilert = torch.cat([mat_scale_tilert, padding_zeros], dim=2)
- mat_scale_tilert = mat_scale_tilert.reshape(exp_num, 1024, 2)
+ target_cols_per_sm = 1024 * scale_cols // num_sms
+ pad_amount = target_cols_per_sm - mat_scale_tilert.shape[-1]
+ if pad_amount > 0:
+ padding_zeros = torch.zeros(
+ (exp_num, num_sms, pad_amount),
+ dtype=scale_in.dtype,
+ device=scale_in.device,
+ )
+ mat_scale_tilert = torch.cat([mat_scale_tilert, padding_zeros], dim=2)
+ mat_scale_tilert = mat_scale_tilert.reshape(exp_num, 1024, scale_cols)
if arch_name == "glm_5":
if mat_scale_tilert.dtype != torch.float32:
print(
@@ -153,7 +151,7 @@ def convert_to_general(
+ "is not float32, convert to float32."
)
mat_scale_tilert = mat_scale_tilert.to(torch.float32)
- else: # DS v3.2, use bfloat16 for mat_scale_tilert
+ else:
mat_scale_tilert = mat_scale_tilert.to(torch.bfloat16)
return mat_in_swizzled.contiguous(), mat_scale_tilert.contiguous()
@@ -176,6 +174,11 @@ def __call__(self) -> list[str]:
class ExpertDownAllReduce(TileRTModule):
"""ExpertDownAllReduce module."""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ExpertDownAllReduceAlgorithm.GENERAL],
+ "glm_5": [ExpertDownAllReduceAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -204,15 +207,14 @@ def __init__(
self.hidden_out: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
self.is_init = False
- self.exp_down_allreduce_func: Callable | None = None
- if self.arch_name == "deepseek_v3_2":
- self.exp_down_allreduce_func = expert_down_allreduce
- elif self.arch_name == "glm_5":
- self.exp_down_allreduce_func = expert_down_allreduce_glm5
+ if self.arch_name in ("deepseek_v3_2", "glm_5"):
+ self.compute_kernel_type = "bf16"
else:
raise ValueError(f"Unsupported architecture: {self.arch_name}")
+ self.model_arch = self.arch_name
+
self.tilert_weights_alias = ExpertDownAllReduceTilertWeightsAlias()
self.tensor_alias = ["exp_down_weights", "exp_down_scales"]
self.ref_tensor_alias = (
@@ -316,24 +318,21 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) ->
self.is_init = True
def init_random_weights(self, device_id: int = 0) -> None:
- down_weights = [
- torch.randn(
- self.dim, self.moe_inter_dim, dtype=torch.bfloat16, device=f"cuda:{device_id}"
- ).to(torch.float8_e4m3fn)
- for _ in range(self.n_routed_experts + 1)
- ]
+ n = self.n_routed_experts + 1
+ dev = f"cuda:{device_id}"
+ down_weights = list(
+ torch.randn(n, self.dim, self.moe_inter_dim, dtype=torch.bfloat16, device=dev)
+ .to(torch.float8_e4m3fn)
+ .unbind(0)
+ )
dim_scale_dim = self.dim // self.block_size
moe_inter_dim_scale_dim = self.moe_inter_dim // self.block_size
scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
- down_scales = [
+ down_scales = list(
torch.randn(
- dim_scale_dim,
- moe_inter_dim_scale_dim,
- dtype=scale_dtype,
- device=f"cuda:{device_id}",
- )
- for _ in range(self.n_routed_experts + 1)
- ]
+ n, dim_scale_dim, moe_inter_dim_scale_dim, dtype=scale_dtype, device=dev
+ ).unbind(0)
+ )
state_dict = dict(
zip(
self.ref_tensor_alias,
@@ -367,6 +366,7 @@ def golden_forward(
hidden_out_w2_list.append(hidden_out_w2_sel * scores[0, s, i])
hidden_out_w2 = torch.stack(hidden_out_w2_list, dim=0).to(torch.bfloat16)
hidden_out_w2 = torch.sum(hidden_out_w2, dim=0)
+
hidden_out_list.append(hidden_out_w2)
hidden_out = torch.stack(hidden_out_list, dim=0)
return hidden_out[None, ...]
@@ -379,9 +379,8 @@ def tilert_forward(
x_in: torch.Tensor,
flag: int,
) -> torch.Tensor:
- assert self.exp_down_allreduce_func is not None
assert self.hidden_out is not None
- self.exp_down_allreduce_func(
+ expert_down_allreduce(
vec_in,
self.tilert_weights,
self.tilert_scales,
@@ -391,6 +390,8 @@ def tilert_forward(
flag,
self.hidden_out,
self.profile_logs,
+ self.model_arch,
+ self.compute_kernel_type,
)
return self.hidden_out
diff --git a/python/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py b/tilert/models/glm_5/_dsa_v32/ops/expert_sel_up_gate_silu.py
similarity index 86%
rename from python/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py
rename to tilert/models/glm_5/_dsa_v32/ops/expert_sel_up_gate_silu.py
index 50a0a67..e2d96eb 100644
--- a/python/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/expert_sel_up_gate_silu.py
@@ -4,15 +4,12 @@
from enum import Enum
import numpy as np
-
-# from typing import Any
import torch
import torch.nn.functional as F
from tilert.models.base import TileRTModule, TilertWeightsConverter
from tilert.models.common import weight_dequant
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
from tilert.utils import get_profile_log_tensor
__all__ = [
@@ -34,9 +31,11 @@ def expert_select_up_gate_silu(
expert_indices_out: torch.Tensor,
profile_logs: torch.Tensor,
algorithm: str = "fp8mma",
+ *,
+ model_arch: str,
) -> None:
"""Expert SelectUpGateSiLU operation."""
- args_list = [
+ torch.ops.tilert.expert_select_up_gate_silu_op(
hidden_in,
scores_in,
bias_in,
@@ -45,9 +44,9 @@ def expert_select_up_gate_silu(
expert_probs_out,
expert_indices_out,
profile_logs,
+ model_arch,
algorithm,
- ]
- torch.ops.tilert.expert_select_up_gate_silu_op(*args_list)
+ )
@dataclass
@@ -114,7 +113,6 @@ class ExpertSelectUpGateSiLUWeightsConverter(TilertWeightsConverter):
def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
assert mat_in.dtype == torch.float8_e4m3fn
- # PTX isa fig.88
pre_shape = mat_in.shape[:-2]
mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
@@ -122,7 +120,6 @@ def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
@staticmethod
def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
- # PTX isa fig.88
pre_shape = mat_in.shape[:-2]
mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
@@ -130,7 +127,6 @@ def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
@staticmethod
def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
- # PTX isa fig.88
pre_shape = mat_in.shape[:-2]
mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
@@ -155,7 +151,6 @@ def tilert_to_tilert_144sm(
weights_trt = mat_in.reshape(exp_num, 128, 4, 7168)
weights_w1 = weights_trt[:, :, :2].reshape(exp_num, 256, 7168)
weights_w3 = weights_trt[:, :, 2:].reshape(exp_num, 256, 7168)
- # to 16x1024 blocks
weights_w1 = weights_w1.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3)
weights_w3 = weights_w3.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3)
if mma_type == "16x32":
@@ -177,7 +172,6 @@ def tilert_to_tilert_144sm(
assert weights.shape == (exp_num, 16, 7, 32, 1024)
weights = weights.reshape(exp_num, 16, 7, 32 * 1024)
- # For scales, first unswizzle
scales_unswizzled = torch.zeros(exp_num, 4, 56)
for i in range(64):
if ((i % 8) * 8 + i // 8) < 56:
@@ -220,62 +214,63 @@ def tilert_to_tilert_144sm_mma(
def convert_to_mma(
self, weights_list: list[torch.Tensor], algorithm: str = "fp8mma"
) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert the weights to mma format.
-
- Args:
- weights: List of weights.
-
- Returns:
- Tuple of weights.
- """
+ """Convert the weights to mma format."""
args = self.model_args
dim = args.dim
- pages = dim // 1024 # 6 for GLM5, 7 for DS v3.2
+ pages = dim // 1024
dim_scale_dim = dim // args.block_size
with torch.inference_mode():
- # w1: gate, w3: up
bias_or_gamma, weights_w1, scales_w1, weights_w3, scales_w3 = weights_list
exp_num = weights_w1.shape[0]
- # to 16x1024 blocks
- weights_w1 = weights_w1.reshape(exp_num, 16, 16, pages, 1024).transpose(2, 3)
- weights_w3 = weights_w3.reshape(exp_num, 16, 16, pages, 1024).transpose(2, 3)
- # to 16x32 blocks and swizzle
+ moe_rows = weights_w1.shape[1]
+ n_row_groups = moe_rows // 16
+ scale_m_dim = moe_rows // args.block_size
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, 16, pages, 1024).transpose(2, 3)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, 16, pages, 1024).transpose(2, 3)
if algorithm == "fp8mma":
- weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 32, 32).transpose(3, 4)
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 32, 32).transpose(
+ 3, 4
+ )
weights_w1 = self._swizzle_qmma_16x32(weights_w1)
- weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 1024)
- weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 32, 32).transpose(3, 4)
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 32, 32).transpose(
+ 3, 4
+ )
weights_w3 = self._swizzle_qmma_16x32(weights_w3)
- weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 1024)
elif algorithm == "fp16mma":
- weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 64, 16).transpose(3, 4)
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 64, 16).transpose(
+ 3, 4
+ )
weights_w1 = self._swizzle_mma_16x16(weights_w1)
- weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 1024)
- weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 64, 16).transpose(3, 4)
+ weights_w1 = weights_w1.reshape(exp_num, n_row_groups, pages, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 64, 16).transpose(
+ 3, 4
+ )
weights_w3 = self._swizzle_mma_16x16(weights_w3)
- weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, n_row_groups, pages, 16, 1024)
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
- # concat w1 and w3
weights: torch.Tensor = torch.cat([weights_w1, weights_w3], dim=3)
- assert weights.shape == (exp_num, 16, pages, 32, 1024)
- weights = weights.reshape(exp_num, 16, pages, 32 * 1024)
+ assert weights.shape == (exp_num, n_row_groups, pages, 32, 1024)
+ weights = weights.reshape(exp_num, n_row_groups, pages, 32 * 1024)
+ scales_per_page = 1024 // args.block_size
+ repeat_factor = n_row_groups // scale_m_dim
scales_w1 = (
- scales_w1.reshape(exp_num, 2, 1, dim_scale_dim)
- .repeat(1, 1, 8, 1)
- .reshape(exp_num, 16, 1, pages, 8)
+ scales_w1.reshape(exp_num, scale_m_dim, 1, dim_scale_dim)
+ .repeat(1, 1, repeat_factor, 1)
+ .reshape(exp_num, n_row_groups, 1, pages, scales_per_page)
)
scales_w1 = scales_w1.transpose(2, 3)
scales_w3 = (
- scales_w3.reshape(exp_num, 2, 1, dim_scale_dim)
- .repeat(1, 1, 8, 1)
- .reshape(exp_num, 16, 1, pages, 8)
+ scales_w3.reshape(exp_num, scale_m_dim, 1, dim_scale_dim)
+ .repeat(1, 1, repeat_factor, 1)
+ .reshape(exp_num, n_row_groups, 1, pages, scales_per_page)
)
scales_w3 = scales_w3.transpose(2, 3)
scales = torch.cat([scales_w1, scales_w3], dim=3)
- assert scales.shape == (exp_num, 16, pages, 2, 8)
+ assert scales.shape == (exp_num, n_row_groups, pages, 2, scales_per_page)
if self.model_args.arch_name == "glm_5":
if scales.dtype != torch.float32:
@@ -285,14 +280,16 @@ def convert_to_mma(
+ "is not float32, convert to float32."
)
scales = scales.to(torch.float32)
- else: # DS v3.2, use bfloat16 for scales
+ else:
scales = scales.to(torch.bfloat16)
- scales = scales.reshape(exp_num, 16, pages, 2 * 8).view(dtype=torch.float8_e4m3fn)
+ scales = scales.reshape(exp_num, n_row_groups, pages, 2 * scales_per_page).view(
+ dtype=torch.float8_e4m3fn
+ )
weights_and_scales = torch.zeros(
exp_num,
- 16,
+ n_row_groups,
pages,
32 * 1024 + 128,
dtype=torch.float8_e4m3fn,
@@ -335,6 +332,17 @@ def convert_to_fp16mma(
class ExpertSelectUpGateSiLU(TileRTModule):
"""ExpertSelectUpGateSiLU module"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ ExpertSelectUpGateSiLUAlgorithm.FP8MMA,
+ ExpertSelectUpGateSiLUAlgorithm.FP16MMA,
+ ],
+ "glm_5": [
+ ExpertSelectUpGateSiLUAlgorithm.FP8MMA,
+ ExpertSelectUpGateSiLUAlgorithm.FP16MMA,
+ ],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -377,18 +385,14 @@ def __init__(
)
)
- # reference weights
self.ref_bias: torch.Tensor | None = None
self.ref_gate: torch.Tensor | None = None
self.ref_up: torch.Tensor | None = None
- # tilert weights
self.tilert_bias: torch.Tensor | None = None
self.tilert_weights: torch.Tensor | None = None
- # for compatibility, to be removed in the future
self.tilert_scales = torch.zeros(1, dtype=torch.bfloat16, device=torch.device("cuda"))
- # tilert vars
self.hidden_out: torch.Tensor | None = None
self.expert_probs: torch.Tensor | None = None
self.expert_indices: torch.Tensor | None = None
@@ -423,7 +427,7 @@ def get_weights_list(self) -> list[torch.Tensor]:
@staticmethod
def process_gate_up_weights(
- key_prefix: str, # e.g. mlp.shared_experts or mlp.experts.{id}
+ key_prefix: str,
weights_hf: dict[str, torch.Tensor],
num_devices: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -537,17 +541,11 @@ def init_reference_weights(
self.ref_up = torch.stack(ref_up_list, dim=0)
def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
- """
- Initialize the tilert weights.
-
- Args:
- state_dict: State dict keyed by tilert_weights_alias() (per-device).
- """
+ """Initialize the tilert weights."""
assert self.algorithm is not None, "Algorithm is not set"
weights_list = [state_dict[alias] for alias in self.tilert_weights_alias()]
- self.tilert_bias, self.tilert_weights = ExpertSelectUpGateSiLUWeightsConverter(
- self.model_args, self.num_devices
- ).dispatch(self.algorithm, weights_list)
+ converter = ExpertSelectUpGateSiLUWeightsConverter(self.model_args, self.num_devices)
+ self.tilert_bias, self.tilert_weights = converter.dispatch(self.algorithm, weights_list)
def init_tilert_vars(self, batch_size: int, seq_len: int, device: str = "cuda") -> None:
"""
@@ -557,7 +555,6 @@ def init_tilert_vars(self, batch_size: int, seq_len: int, device: str = "cuda")
batch_size: Batch size.
seq_len: Sequence length.
"""
- # tilert vars
self.hidden_out = torch.zeros(
(
batch_size,
@@ -589,30 +586,31 @@ def init_random_weights(self, device: str = "cuda") -> None:
Returns:
None
"""
+ n = self.n_routed_experts + 1
bias = torch.randn(self.n_routed_experts, dtype=torch.float32, device=device)
- gate_weights = [
- torch.randn(self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device).to(
- torch.float8_e4m3fn
- )
- for _ in range(self.n_routed_experts + 1)
- ]
- up_weights = [
- torch.randn(self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device).to(
- torch.float8_e4m3fn
- )
- for _ in range(self.n_routed_experts + 1)
- ]
+ gate_weights = list(
+ torch.randn(n, self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device)
+ .to(torch.float8_e4m3fn)
+ .unbind(0)
+ )
+ up_weights = list(
+ torch.randn(n, self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device)
+ .to(torch.float8_e4m3fn)
+ .unbind(0)
+ )
moe_inter_dim_scale_dim = self.moe_inter_dim // self.block_size
dim_scale_dim = self.dim // self.block_size
scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
- gate_scales = [
- torch.randn(moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device)
- for _ in range(self.n_routed_experts + 1)
- ]
- up_scales = [
- torch.randn(moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device)
- for _ in range(self.n_routed_experts + 1)
- ]
+ gate_scales = list(
+ torch.randn(
+ n, moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device
+ ).unbind(0)
+ )
+ up_scales = list(
+ torch.randn(
+ n, moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device
+ ).unbind(0)
+ )
tensor_list = [
bias,
*gate_weights,
@@ -652,7 +650,6 @@ def _ref_expert_select_ds(self, scores: torch.Tensor) -> tuple[torch.Tensor, tor
return weights, indices
def _ref_expert_select_glm5(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- # flatten_dim = np.prod(scores.size()[:-1])
scores = scores.sigmoid()
original_scores = scores
if self.ref_bias is not None:
@@ -682,7 +679,6 @@ def golden_forward(
raise ValueError(f"Unsupported architecture: {self.arch_name}")
hidden_out_list = []
for s in range(seq_len):
- # ref up-gate silu
hidden_out_w1_list = []
hidden_out_w3_list = []
hidden_out_w1_shared = x_in[0, s].float() @ self.ref_gate[0].float().T
@@ -721,9 +717,6 @@ def tilert_forward(
self.expert_indices,
self.profile_logs,
self.algorithm.value,
+ model_arch=self.model_args.arch_name,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return self.hidden_out, self.expert_probs, self.expert_indices
diff --git a/tilert/models/glm_5/_dsa_v32/ops/flash_sparse_mla.py b/tilert/models/glm_5/_dsa_v32/ops/flash_sparse_mla.py
new file mode 100644
index 0000000..1d4cc00
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/flash_sparse_mla.py
@@ -0,0 +1,261 @@
+"""Flash Sparse MLA operation module."""
+
+import math
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "flash_sparse_mla",
+ "FlashSparseMLACombine",
+]
+
+
+def flash_sparse_mla(
+ query: torch.Tensor,
+ query_pe: torch.Tensor,
+ key_value: torch.Tensor,
+ key_pe: torch.Tensor,
+ indices: torch.Tensor,
+ cur_pos: torch.Tensor,
+ output: torch.Tensor,
+ profile_logs: torch.Tensor,
+ split_size: int = 64,
+ compute_kernel_type: str = "bf16mma",
+ *,
+ model_arch: str,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Flash Sparse MLA operation for GLM5.
+
+ Args:
+ query: Query tensor. (bs, seqlen, heads, dim)
+ query_pe: Query position embedding tensor. (bs, seqlen, heads, pe_dim)
+ key_value: Key-value tensor. (bs, seqlen_kv, dim)
+ key_pe: Key position embedding tensor. (bs, seqlen_kv, pe_dim)
+ indices: Indices tensor. (bs, seqlen, topk)
+ cur_pos: cur_pos tensor. (1)
+ output: Output tensor.
+ profile_logs: Profile logs tensor.
+ split_size: Number of splits.
+ """
+ batch, seqlen, heads, hidden_dim = query.shape
+ if split_size != 64:
+ raise ValueError(
+ "The current implementation of flash_sparse_mla_op only supports split_size=64"
+ )
+ if batch != 1:
+ raise ValueError("The current implementation of flash_sparse_mla_op only supports batch=1")
+ if seqlen > 4:
+ raise ValueError(
+ "The current implementation of flash_sparse_mla_op only supports seqlen<=4"
+ )
+
+ seqlen_kv = key_value.shape[1]
+ index_len = indices.shape[-1]
+ if index_len > seqlen_kv:
+ raise ValueError("index_len must be less than or equal to seqlen_kv")
+
+ device = query.device
+ acc_type = torch.float32
+
+ dim = key_value.shape[-1]
+ max_num_splits = 32
+
+ lse = torch.empty((batch, seqlen, heads), device=device, dtype=acc_type)
+ lse_acc = torch.empty((batch, seqlen, heads, max_num_splits), device=device, dtype=acc_type)
+ output_acc = torch.empty(
+ batch, seqlen, heads, max_num_splits, dim, device=device, dtype=acc_type
+ )
+
+ if heads not in (8, 10, 16, 20):
+ raise ValueError(f"Unsupported heads: {heads}")
+ torch.ops.tilert.flash_sparse_mla_op(
+ query,
+ query_pe,
+ key_value,
+ key_pe,
+ indices,
+ cur_pos,
+ output,
+ output_acc,
+ lse,
+ lse_acc,
+ split_size,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=query.device),
+ )
+ return lse, lse_acc, output_acc
+
+
+class FlashSparseMLACombineAlgorithm(Enum):
+ """FlashSparseMLACombine algorithm."""
+
+ BF16MMA = "bf16mma"
+
+
+class FlashSparseMLACombine(TileRTModule):
+ """Flash Sparse MLA combine module; no weights, uses model_args for scale and config."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [FlashSparseMLACombineAlgorithm.BF16MMA],
+ "glm_5": [FlashSparseMLACombineAlgorithm.BF16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ layer_idx: int = 0,
+ ):
+ super().__init__(
+ type(self).__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ layer_idx=layer_idx,
+ )
+ self.tilert_tensor_alias: list[str] = []
+ self.ref_tensor_alias: list[str] = []
+
+ scale = (model_args.qk_nope_head_dim + model_args.qk_rope_head_dim) ** -0.5
+ if model_args.rope_factor is None:
+ mscale = 1.0
+ else:
+ mscale = 0.1 * math.log(model_args.rope_factor) + 1.0
+ self.softmax_scale = scale * mscale * mscale
+
+ self.profile_logs = get_profile_log_tensor()
+
+ def init_reference_weights(
+ self, state_dict: dict[str, torch.Tensor], device_id: int = 0
+ ) -> None:
+ del state_dict, device_id
+ self.is_ref_weights_init = True
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ del state_dict
+ self.is_tilert_weights_init = True
+
+ def init_random_weights(self) -> None:
+ self.is_ref_weights_init = True
+ self.is_tilert_weights_init = True
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ del batch_size, seq_len
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(
+ self,
+ q_nope: torch.Tensor,
+ q_pe: torch.Tensor,
+ kv_cache: torch.Tensor,
+ pe_cache: torch.Tensor,
+ topk_indices: torch.Tensor,
+ cur_pos: torch.Tensor,
+ ) -> torch.Tensor:
+ """Flash Sparse MLA golden version.
+
+ Args:
+ q_nope: Query tensor. (bs, seqlen, heads, dim)
+ q_pe: Query position embedding tensor. (bs, seqlen, heads, pe_dim)
+ kv_cache: Key-value tensor. (bs, seqlen_kv, dim)
+ pe_cache: Key position embedding tensor. (bs, seqlen_kv, pe_dim)
+ topk_indices: Indices tensor. (bs, seqlen, topk)
+ cur_pos: cur_pos tensor. (1)
+ """
+ batch_size = q_nope.shape[0]
+ seqlen = q_nope.shape[1]
+ seqlen_kv = kv_cache.shape[1]
+
+ start_pos = int(cur_pos.item())
+ mask = (
+ torch.full((seqlen, seqlen_kv), float("-inf")).triu_(start_pos + 1)
+ if seqlen > 1
+ else None
+ )
+
+ scores = (
+ torch.einsum("bshc,btc->bsht", q_nope.float(), kv_cache.float())
+ + torch.einsum("bshr,btr->bsht", q_pe.float(), pe_cache.float())
+ ) * self.softmax_scale
+ index_mask = torch.full(
+ (batch_size, seqlen, seqlen_kv), float("-inf"), device=q_nope.device
+ ).scatter_(-1, topk_indices, 0)
+ if mask is not None:
+ index_mask += mask
+
+ scores += index_mask.unsqueeze(2)
+ scores = scores.softmax(dim=-1, dtype=torch.float32)
+ return torch.einsum("bsht,btc->bshc", scores.to(torch.bfloat16), kv_cache)
+
+ def tilert_forward(
+ self,
+ q_nope: torch.Tensor,
+ q_pe: torch.Tensor,
+ kv_cache: torch.Tensor,
+ pe_cache: torch.Tensor,
+ topk_indices: torch.Tensor,
+ cur_pos: torch.Tensor,
+ ) -> torch.Tensor:
+ """Flash Sparse MLA tilert version.
+
+ Args:
+ q_nope: Query tensor. (bs, seqlen, heads, dim)
+ q_pe: Query position embedding tensor. (bs, seqlen, heads, pe_dim)
+ kv_cache: Key-value tensor. (bs, seqlen_kv, dim)
+ pe_cache: Key position embedding tensor. (bs, seqlen_kv, pe_dim)
+ topk_indices: Indices tensor. (bs, seqlen, topk)
+ cur_pos: cur_pos tensor. (1)
+ """
+ batch_size, seqlen, heads, dim = q_nope.shape
+ v_dim = kv_cache.shape[-1]
+
+ topk_indices = topk_indices.to(torch.int32)
+ topk_indices = topk_indices[..., : kv_cache.shape[1]]
+ device = q_nope.device
+ if any(t.device != device for t in (q_pe, kv_cache, pe_cache, topk_indices, cur_pos)):
+ raise RuntimeError(
+ "flash_sparse_mla inputs must be on the same device: "
+ f"q_nope={device}, q_pe={q_pe.device}, kv_cache={kv_cache.device}, "
+ f"pe_cache={pe_cache.device}, topk_indices={topk_indices.device}, "
+ f"cur_pos={cur_pos.device}"
+ )
+ if self.profile_logs is not None and self.profile_logs.device != device:
+ self.profile_logs = get_profile_log_tensor(device_index=device.index, device=device)
+ output = torch.zeros(
+ (batch_size, seqlen, heads, v_dim), dtype=torch.bfloat16, device=device
+ )
+ flash_sparse_mla(
+ q_nope,
+ q_pe,
+ kv_cache,
+ pe_cache,
+ topk_indices,
+ cur_pos,
+ output,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return output
+
+ def to_tilert_weights(self) -> None:
+ raise NotImplementedError("to_tilert_weights not implemented")
+
+ def __call__(
+ self,
+ q_nope: torch.Tensor,
+ q_pe: torch.Tensor,
+ kv_cache: torch.Tensor,
+ pe_cache: torch.Tensor,
+ topk_indices: torch.Tensor,
+ cur_pos: torch.Tensor,
+ ) -> torch.Tensor:
+ if self.flag_enable_tilert:
+ return self.tilert_forward(q_nope, q_pe, kv_cache, pe_cache, topk_indices, cur_pos)
+ return self.golden_forward(q_nope, q_pe, kv_cache, pe_cache, topk_indices, cur_pos)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/layernorm_rope_rotate.py b/tilert/models/glm_5/_dsa_v32/ops/layernorm_rope_rotate.py
new file mode 100644
index 0000000..4fc8c0d
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/layernorm_rope_rotate.py
@@ -0,0 +1,243 @@
+"""Layernorm_rope_rotate operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+import torch.nn.functional as F
+
+from tilert.models.base import TileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.ops.rotate import rotate_activation
+from tilert.models.utils import apply_rotary_emb
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "layernorm_rope_rotate",
+ "LayerNormRoPERotate",
+ "LayerNormRoPERotateRefWeightsAlias",
+ "LayerNormRoPERotateTilertWeightsAlias",
+]
+
+
+def layernorm_rope_rotate(
+ input_raw: torch.Tensor,
+ cur_pos: torch.Tensor,
+ k_cache_raw: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+) -> None:
+ """
+ Layernorm_rope_rotate operation.
+
+ Layernorm_rope_rotate the input tensor `input_raw` and stores the result in `k_cache_raw`.
+
+ Args:
+ input_raw (torch.Tensor): The input tensor.
+ cur_pos (torch.Tensor): The current position tensor.
+ k_cache_raw (torch.Tensor): The output tensor where the result will be stored.
+ weight (torch.Tensor): The weight tensor.
+ bias (torch.Tensor): The bias tensor.
+ freqs_cis (torch.Tensor): The frequency tensor.
+ profile_logs (torch.Tensor): Tensor for storing profiling logs.
+
+ Returns:
+ None
+ """
+ if input_raw.dtype != torch.bfloat16:
+ raise ValueError("input must be a bfloat16 tensor.")
+ if cur_pos.dtype != torch.int32:
+ raise ValueError("cur_pos must be a int32 tensor.")
+ if k_cache_raw.dtype != torch.bfloat16:
+ raise ValueError("k_cache must be a bfloat16 tensor.")
+
+ if weight.dtype != torch.float32:
+ raise ValueError("weight must be a float32 tensor.")
+
+ if bias.dtype != torch.float32:
+ raise ValueError("bias must be a float32 tensor.")
+
+ if freqs_cis.dtype != torch.float32:
+ raise ValueError("freqs_cis must be a float32 tensor.")
+
+ batch, seq, dim = input_raw.shape
+ if dim != 128:
+ raise ValueError("dim must be 128, as we precompute scale inner kernel")
+ if batch != 1:
+ raise ValueError("batch must be 1 in this version")
+
+ torch.ops.tilert.layernorm_rope_rotate_op(
+ input_raw,
+ cur_pos,
+ k_cache_raw,
+ weight,
+ bias,
+ freqs_cis,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+
+
+@dataclass
+class LayerNormRoPERotateRefWeightsAlias:
+ """Reference weights alias for LayerNormRoPERotate."""
+
+ k_weight = "self_attn.indexer.k_norm.weight"
+ k_bias = "self_attn.indexer.k_norm.bias"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.k_weight, self.k_bias]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class LayerNormRoPERotateTilertWeightsAlias:
+ """TileRT weights alias for LayerNormRoPERotate."""
+
+ k_weight = "k_weights"
+ k_bias = "k_bias"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.k_weight, self.k_bias]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class LayerNormRoPERotateAlgorithm(Enum):
+ """LayerNormRoPERotate algorithm."""
+
+ GENERAL = "general"
+
+
+class LayerNormRoPERotate(TileRTModule):
+ """LayerNormRoPERotate module: LayerNorm + RoPE + rotate on K indexer output."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [LayerNormRoPERotateAlgorithm.GENERAL],
+ "glm_5": [LayerNormRoPERotateAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: LayerNormRoPERotateRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = LayerNormRoPERotateTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else LayerNormRoPERotateRefWeightsAlias()
+ )
+
+ self.rope_head_dim = self.model_args.qk_rope_head_dim
+ self.head_dim = self.model_args.index_head_dim
+
+ self.ref_weight: torch.Tensor | None = None
+ self.ref_bias: torch.Tensor | None = None
+ self.tilert_weight: torch.Tensor | None = None
+ self.tilert_bias: torch.Tensor | None = None
+ self.output: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_weight, self.tilert_bias]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: replicate weight and bias for each device.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor.
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ k_weight = weights_map[self.ref_weights_alias.k_weight][None, ...].repeat(
+ self.num_devices, 1
+ )
+ k_bias = weights_map[self.ref_weights_alias.k_bias][None, ...].repeat(self.num_devices, 1)
+ return {
+ self.tilert_weights_alias.k_weight: k_weight,
+ self.tilert_weights_alias.k_bias: k_bias,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.ref_weight = state_dict[self.ref_weights_alias.k_weight].contiguous().float()
+ self.ref_bias = state_dict[self.ref_weights_alias.k_bias].contiguous().float()
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.tilert_weight = state_dict[self.tilert_weights_alias.k_weight].contiguous().float()
+ self.tilert_bias = state_dict[self.tilert_weights_alias.k_bias].contiguous().float()
+
+ def init_random_weights(self) -> None:
+ ref_weight = torch.ones(self.head_dim, dtype=torch.float32)
+ ref_bias = torch.zeros(self.head_dim, dtype=torch.float32)
+ ref_state_dict = dict(zip(self.ref_weights_alias(), [ref_weight, ref_bias]))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.cur_pos = torch.tensor([0], dtype=torch.int32)
+ self.output = torch.zeros((batch_size, seq_len, self.head_dim), dtype=torch.bfloat16)
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ assert self.ref_weight is not None and self.ref_bias is not None
+ k = F.layer_norm(
+ idx_k.float(),
+ (self.head_dim,),
+ self.ref_weight,
+ self.ref_bias,
+ 1e-6,
+ ).to(idx_k.dtype)
+ k_pe, k_nope = torch.split(
+ k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+ )
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+ k = torch.cat([k_pe, k_nope], dim=-1)
+ return rotate_activation(k)
+
+ def tilert_forward(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_weight is not None and self.tilert_bias is not None
+ assert self.output is not None and self.profile_logs is not None
+ rope_freqs = (
+ torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1).float().unsqueeze(1)
+ )
+ layernorm_rope_rotate(
+ idx_k,
+ self.cur_pos,
+ self.output,
+ self.tilert_weight,
+ self.tilert_bias,
+ rope_freqs,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.output
+
+ def __call__(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ if self.flag_enable_tilert:
+ return self.tilert_forward(idx_k, freqs_cis)
+ return self.golden_forward(idx_k, freqs_cis)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/padded_allreduce_add.py b/tilert/models/glm_5/_dsa_v32/ops/padded_allreduce_add.py
new file mode 100644
index 0000000..a6490c9
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/padded_allreduce_add.py
@@ -0,0 +1,147 @@
+"""PaddedAllReduceAdd operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "padded_allreduce_add",
+ "PaddedAllReduceAdd",
+]
+
+
+def padded_allreduce_add(
+ partial_buf: torch.Tensor,
+ x_in: torch.Tensor,
+ flag: int,
+ vec_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """Padded AllReduce + residual add for Device Group A (GPU 0).
+
+ GPU 0 contributes zeros to the 8-GPU AllReduce, then adds the residual.
+
+ Args:
+ partial_buf: Zero-filled partial buffer [1, L, hidden_dim] bf16.
+ x_in: Residual input [1, L, hidden_dim] bf16.
+ flag: AllReduce sync flag.
+ vec_out: Output tensor [1, L, hidden_dim] bf16.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.padded_allreduce_add_op(
+ partial_buf, x_in, flag, vec_out, profile_logs, model_arch, compute_kernel_type
+ )
+
+
+class PaddedAllReduceAddAlgorithm(Enum):
+ """PaddedAllReduceAdd algorithm."""
+
+ BF16 = "bf16"
+
+
+class PaddedAllReduceAdd(TileRTModule):
+ """PaddedAllReduceAdd module — zero-partial AllReduce + residual add."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [PaddedAllReduceAddAlgorithm.BF16],
+ "glm_5": [PaddedAllReduceAddAlgorithm.BF16],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.dim = self.model_args.dim
+
+ self.partial_buf: torch.Tensor | None = None
+
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_var_init = False
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate output buffer and persistent zero-filled partial buffer.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ )
+ self.partial_buf = torch.zeros(
+ (batch_size, seq_len, self.dim),
+ dtype=torch.bfloat16,
+ device=f"cuda:{self.device_id}",
+ )
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{self.device_id}")
+ self.is_var_init = True
+
+ def golden_forward(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ """Golden reference: allreduce(zeros) + x_in = x_in (single-GPU).
+
+ On a single GPU, allreduce of zeros returns zeros, so output = x_in.
+
+ Args:
+ x_in: Residual input [1, L, hidden_dim].
+
+ Returns:
+ Output tensor (copy of x_in).
+ """
+ return x_in.clone()
+
+ def tilert_forward(
+ self,
+ x_in: torch.Tensor,
+ flag: int,
+ ) -> torch.Tensor:
+ """Run TileRT kernel forward.
+
+ Args:
+ x_in: Residual input [1, L, hidden_dim].
+ flag: AllReduce sync flag.
+
+ Returns:
+ Output tensor [1, L, hidden_dim].
+ """
+ assert self.hidden_out is not None
+ assert self.partial_buf is not None
+ assert self.profile_logs is not None
+ padded_allreduce_add(
+ self.partial_buf,
+ x_in,
+ flag,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(x_in)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/projo_wkvb.py b/tilert/models/glm_5/_dsa_v32/ops/projo_wkvb.py
new file mode 100644
index 0000000..3e99f0e
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/projo_wkvb.py
@@ -0,0 +1,483 @@
+"""ProjOWkvb operation module."""
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import init_func, weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "projo_wkvb",
+ "ProjoWKVb",
+ "ProjoWKVbAlgorithm",
+ "ProjoWKVbWeightsConverter",
+ "ProjoWKVbRefWeightsAlias",
+ "ProjoWKVbTilertWeightsAlias",
+]
+
+
+def projo_wkvb(
+ o_in: torch.Tensor,
+ wkv_b_b: torch.Tensor,
+ wkv_b_scales: torch.Tensor,
+ output: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "fp16mma",
+) -> None:
+ """
+ Define the ProjOWkvb operation.
+
+ Args:
+ o_in: Input tensor.
+ wkv_b_b: Weight tensor.
+ wkv_b_scales: Scale tensor.
+ output: Output tensor.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Kernel type ("fp16mma" for both DSv32 and GLM5).
+ """
+ torch.ops.tilert.projo_wkvb_op(
+ o_in,
+ wkv_b_b,
+ wkv_b_scales,
+ output,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=o_in.device),
+ )
+
+
+class ProjoWKVbAlgorithm(Enum):
+ """ProjoWKVb algorithm"""
+
+ GENERAL = "general"
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+class ProjoWKVbWeightsConverter(TilertWeightsConverter):
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args, num_devices)
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a [*, 16, 16] block for the packed weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(mat_in: torch.Tensor, k_dim: int, pages: int) -> torch.Tensor:
+ """Swizzle a [*, 16, K] matrix for the paged weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == k_dim
+ pre_shape = mat_in.shape[:-2]
+ k_per_page = k_dim // pages
+ n_k_tiles = k_per_page // 16
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = ProjoWKVbWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def convert_to_fp16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the packed format expected by the kernel."""
+ with torch.inference_mode():
+ wkv_b_b, wkv_b_b_scales = self.convert_to_general(weights)
+
+ n_heads = wkv_b_b.size(0)
+ v_head_dim = wkv_b_b.size(1)
+ kv_lora_rank = wkv_b_b.size(2)
+ num_ctas = 80
+ rows_per_cta = (n_heads * v_head_dim) // num_ctas
+
+ is_glm5 = self.model_args.arch_name == "glm_5"
+
+ w_flat = wkv_b_b.reshape(num_ctas, rows_per_cta // 16, 16, kv_lora_rank)
+ w_swizzled = ProjoWKVbWeightsConverter._swizzle_mma_16x16_for_pages(
+ w_flat, kv_lora_rank, pages=1
+ )
+ w_bytes = w_swizzled.reshape(num_ctas, -1)
+
+ scale_k_block = 128
+ n_scale_k = kv_lora_rank // scale_k_block
+ ctas_per_head = num_ctas // n_heads
+
+ if is_glm5:
+ ctas_per_scale_row = 64 // rows_per_cta
+ scales_per_cta = wkv_b_b_scales.repeat_interleave(ctas_per_scale_row, dim=1)
+ scales_per_cta = scales_per_cta.reshape(num_ctas, n_scale_k)
+ else:
+ scales_per_cta = wkv_b_b_scales.squeeze(1).repeat_interleave(ctas_per_head, dim=0)
+
+ scale_dtype = torch.float32
+ scales_per_cta = scales_per_cta.to(scale_dtype)
+
+ mat_bytes = rows_per_cta * kv_lora_rank
+ scale_bytes = n_scale_k * 4
+ page_size = (mat_bytes + scale_bytes + 127) // 128 * 128
+
+ scales_raw = scales_per_cta.contiguous().view(torch.float8_e4m3fn)
+ padding_size = page_size - mat_bytes - scales_raw.shape[-1]
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_b_b.device
+ )
+ return torch.cat([w_bytes, scales_raw, padding], dim=-1).contiguous()
+
+ def convert_to_bf16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the packed format expected by the BF16 kernel."""
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim
+ left_head_dim = wkvb_head_dim % self.model_args.block_size
+ hd_block = left_head_dim if left_head_dim != 0 else self.model_args.block_size
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local_heads = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local_heads % 2 != 0:
+ n_local_heads += 1
+
+ v_head_dim = self.model_args.v_head_dim
+ kv_lora_rank = self.model_args.kv_lora_rank
+ n_block = self.model_args.block_size
+
+ w = tilert_wkv_b_weights
+ s = tilert_wkv_b_scales
+ if self.model_args.n_heads % self.num_devices != 0:
+ n_current = w.size(0)
+ if n_current < n_local_heads:
+ pad_w = torch.zeros(
+ n_local_heads - n_current, *w.shape[1:], dtype=w.dtype, device=w.device
+ )
+ w = torch.cat([w, pad_w], dim=0)
+ pad_s = torch.zeros(
+ n_local_heads - n_current, *s.shape[1:], dtype=s.dtype, device=s.device
+ )
+ s = torch.cat([s, pad_s], dim=0)
+
+ s = s.float()
+ s = s.repeat_interleave(hd_block, dim=1).repeat_interleave(n_block, dim=2)
+ wkv_bf16 = (w.float() * s).to(torch.bfloat16)
+ n_heads = n_local_heads
+
+ num_ctas = 80
+ rows_per_cta = (n_heads * v_head_dim) // num_ctas
+
+ w_flat = wkv_bf16.reshape(num_ctas, rows_per_cta // 16, 16, kv_lora_rank)
+ w_swizzled = ProjoWKVbWeightsConverter._swizzle_mma_16x16_for_pages(
+ w_flat, kv_lora_rank, pages=1
+ )
+ w_bytes = w_swizzled.reshape(num_ctas, -1).contiguous().view(torch.float8_e4m3fn)
+
+ mat_bytes = rows_per_cta * kv_lora_rank * 2
+ page_size = (mat_bytes + 127) // 128 * 128
+ padding_size = page_size - w_bytes.shape[-1]
+
+ if padding_size > 0:
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_bf16.device
+ )
+ return torch.cat([w_bytes, padding], dim=-1).contiguous()
+ return w_bytes.contiguous()
+
+ def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ wkv_b_b = tilert_wkv_b_weights.contiguous()
+ wkv_b_b_scales = tilert_wkv_b_scales.contiguous()
+ if self.model_args.arch_name == "glm_5":
+ if wkv_b_b_scales.dtype != torch.float32:
+ print(
+ "Warning: ProjoWKVbWeightsConverter: "
+ + f"wkv_b_b_scales.dtype: {wkv_b_b_scales.dtype} "
+ + "is not float32, convert to float32."
+ )
+ wkv_b_b_scales = wkv_b_b_scales.to(torch.float32)
+ else:
+ wkv_b_b_scales = wkv_b_b_scales.to(torch.bfloat16)
+
+ wkv_b_b = wkv_b_b.detach()
+ wkv_b_b_scales = wkv_b_b_scales.detach()
+
+ if self.model_args.n_heads % self.num_devices != 0:
+ n_target = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_target % 2 != 0:
+ n_target += 1
+ n_current = wkv_b_b.size(0)
+ if n_current < n_target:
+ pad_b = torch.zeros(
+ n_target - n_current,
+ *wkv_b_b.shape[1:],
+ dtype=wkv_b_b.dtype,
+ device=wkv_b_b.device,
+ )
+ wkv_b_b = torch.cat([wkv_b_b, pad_b], dim=0)
+ pad_s = torch.zeros(
+ n_target - n_current,
+ *wkv_b_b_scales.shape[1:],
+ dtype=wkv_b_b_scales.dtype,
+ device=wkv_b_b_scales.device,
+ )
+ wkv_b_b_scales = torch.cat([wkv_b_b_scales, pad_s], dim=0)
+ wkv_b_b = wkv_b_b.contiguous()
+ wkv_b_b_scales = wkv_b_b_scales.contiguous()
+
+ return wkv_b_b, wkv_b_b_scales
+
+
+@dataclass
+class ProjoWKVbRefWeightsAlias:
+ """Reference weights alias for ProjoWKVb."""
+
+ wkv_b_weights = "self_attn.kv_b_proj.weight"
+ wkv_b_scales = "self_attn.kv_b_proj.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.wkv_b_weights, self.wkv_b_scales]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class ProjoWKVbTilertWeightsAlias:
+ """TileRT weights alias for ProjoWKVb."""
+
+ wkv_b_weights = "wkv_b2_weights"
+ wkv_b_scales = "wkv_b2_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.wkv_b_weights, self.wkv_b_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ProjoWKVb(TileRTModule):
+ """ProjoWKVb module: O projection (wkv_b) for output."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjoWKVbAlgorithm.FP16MMA],
+ "glm_5": [ProjoWKVbAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: ProjoWKVbRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = ProjoWKVbTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else ProjoWKVbRefWeightsAlias()
+ )
+
+ self.ref_wkv_b: torch.Tensor | None = None
+ self.tilert_wkv_b_b: torch.Tensor | None = None
+ self.tilert_wkv_b_b_scales: torch.Tensor | None = None
+ self.output: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ self.num_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local % 2 != 0:
+ n_local += 1
+ self.num_local_heads = n_local
+
+ self.wkvb_lora_rank = self.model_args.kv_lora_rank
+ self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size
+
+ self.wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim
+ self.wkvb_v_head_dim = self.model_args.v_head_dim
+ left_head_dim = self.wkvb_head_dim % self.model_args.block_size
+ if left_head_dim != 0:
+ assert self.model_args.block_size % left_head_dim == 0
+ self.head_dim_block_size = left_head_dim
+ self.head_dim_scale_repeat = self.model_args.block_size // self.head_dim_block_size
+ else:
+ self.head_dim_scale_repeat = 1
+ self.head_dim_block_size = self.model_args.block_size
+ self.wkvb_head_qsize = self.wkvb_head_dim // self.head_dim_block_size
+ self.wkvb_v_head_qsize = self.wkvb_v_head_dim // self.head_dim_block_size
+
+ self.compute_kernel_type = "fp16mma"
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_wkv_b_b, self.tilert_wkv_b_b_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: split weights and scales per device.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor.
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights]
+ kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales]
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ dev_weights = kv_b_proj_weight.view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = kv_b_proj_weight_scale.view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+ else:
+ from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqbWeightsConverter,
+ )
+
+ wq_b_list, scale_list = RmsnormProjqWqbWeightsConverter._redistribute_heads(
+ kv_b_proj_weight,
+ kv_b_proj_weight_scale,
+ n_total_heads=self.model_args.n_heads,
+ n_local_heads=self.num_local_heads,
+ num_devices=self.num_devices,
+ qk_head_dim=self.wkvb_head_dim,
+ block_size=self.model_args.block_size,
+ )
+ dev_weights = torch.stack(wq_b_list, dim=0).view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = torch.stack(scale_list, dim=0).view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+
+ wkvb = dev_weights[:, :, -self.wkvb_v_head_dim :]
+ wkvb_scales = (
+ dev_scales.contiguous()
+ .repeat(1, 1, self.head_dim_scale_repeat, 1)
+ .view(
+ self.num_devices,
+ self.num_local_heads,
+ self.wkvb_head_qsize,
+ self.wkvb_lora_rank_qsize,
+ )
+ .contiguous()[:, :, -self.wkvb_v_head_qsize :]
+ )
+ return {
+ self.tilert_weights_alias.wkv_b_weights: wkvb.contiguous(),
+ self.tilert_weights_alias.wkv_b_scales: wkvb_scales.contiguous(),
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ sharding_size = self.num_local_heads * self.wkvb_head_dim
+ sharding_start = self.device_id * sharding_size
+ sharding_end = sharding_start + sharding_size
+ wkv_b = weight_dequant(
+ state_dict[self.ref_weights_alias.wkv_b_weights],
+ state_dict[self.ref_weights_alias.wkv_b_scales],
+ )
+ wkv_b = wkv_b[sharding_start:sharding_end, :]
+ wkv_b = wkv_b.view(self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank)
+ self.ref_wkv_b = wkv_b[:, -self.wkvb_v_head_dim :]
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.init_tilert_weights_hmma(state_dict)
+
+ def init_tilert_weights_hmma(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with HMMA-packed weights."""
+ packed = ProjoWKVbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ ProjoWKVbAlgorithm.FP16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_b = packed
+ self.tilert_wkv_b_b_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "fp16mma"
+
+ def init_tilert_weights_hmma_bf16(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with BF16 HMMA-packed weights (dequantized, no scales)."""
+ packed = ProjoWKVbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ ProjoWKVbAlgorithm.BF16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_b = packed
+ self.tilert_wkv_b_b_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "bf16mma"
+
+ def init_random_weights(self) -> None:
+ padded_total_heads = self.num_local_heads * self.num_devices
+ wkv_b = init_func(
+ torch.empty(
+ padded_total_heads * self.wkvb_head_dim,
+ self.wkvb_lora_rank,
+ dtype=torch.float8_e4m3fn,
+ )
+ )
+ wkv_b_scales = init_func(
+ torch.empty(
+ padded_total_heads * self.wkvb_head_dim // self.model_args.block_size,
+ self.wkvb_lora_rank_qsize,
+ dtype=torch.float32,
+ )
+ )
+ ref_state_dict = dict(
+ zip(
+ self.ref_weights_alias(),
+ [wkv_b, wkv_b_scales],
+ )
+ )
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()})
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.output = torch.zeros(
+ (batch_size, seq_len, self.num_local_heads, self.wkvb_v_head_dim),
+ dtype=torch.bfloat16,
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, x_out: torch.Tensor) -> torch.Tensor:
+ assert self.ref_wkv_b is not None
+ return torch.einsum("bshc,hdc->bshd", x_out, self.ref_wkv_b)
+
+ def tilert_forward(self, x_out: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_wkv_b_b is not None
+ assert self.tilert_wkv_b_b_scales is not None
+ assert self.output is not None
+ assert self.profile_logs is not None
+ projo_wkvb(
+ x_out,
+ self.tilert_wkv_b_b,
+ self.tilert_wkv_b_b_scales,
+ self.output,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ compute_kernel_type=self.compute_kernel_type,
+ )
+ return self.output
diff --git a/tilert/models/glm_5/_dsa_v32/ops/projq_wqb.py b/tilert/models/glm_5/_dsa_v32/ops/projq_wqb.py
new file mode 100644
index 0000000..c40ca51
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/projq_wqb.py
@@ -0,0 +1,466 @@
+"""ProjQB operation module."""
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import init_func, weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "projq_wqb",
+ "ProjqWqb",
+ "ProjqWqbAlgorithm",
+ "ProjqWqbWeightsConverter",
+ "ProjqWqbRefWeightsAlias",
+ "ProjqWqbTilertWeightsAlias",
+]
+
+
+def projq_wqb(
+ q_nope_in: torch.Tensor,
+ wkv_b_a: torch.Tensor,
+ wkv_b_a_scales: torch.Tensor,
+ output: torch.Tensor,
+ profile_logs: torch.Tensor,
+ compute_kernel_type: str = "fp16mma",
+ *,
+ model_arch: str,
+) -> None:
+ """
+ Define the ProjqWqb operation.
+
+ Args:
+ q_nope_in: Input tensor.
+ wkv_b_a: Weight tensor.
+ wkv_b_a_scales: Scale tensor.
+ output: Output tensor.
+ profile_logs: Profile logs tensor.
+ compute_kernel_type: Kernel type ("fp16mma").
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ """
+ torch.ops.tilert.projq_wqb_op(
+ q_nope_in,
+ wkv_b_a,
+ wkv_b_a_scales,
+ output,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+
+
+class ProjqWqbAlgorithm(Enum):
+ """ProjqWqb algorithm"""
+
+ GENERAL = "general"
+ FP16MMA = "fp16mma"
+ BF16MMA = "bf16mma"
+
+
+class ProjqWqbWeightsConverter(TilertWeightsConverter):
+ def __init__(self, model_args: ModelArgs, num_devices: int, head_dim_block_size: int):
+ super().__init__(model_args, num_devices)
+ self.head_dim_block_size = head_dim_block_size
+ self.impl_block_size = 64
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle a [*, 16, 16] block for the packed weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(mat_in: torch.Tensor, k_dim: int, pages: int) -> torch.Tensor:
+ """Swizzle a [*, 16, K] matrix for the paged weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == k_dim
+ pre_shape = mat_in.shape[:-2]
+ k_per_page = k_dim // pages
+ n_k_tiles = k_per_page // 16
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = ProjqWqbWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def convert_to_fp16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the packed format expected by the kernel."""
+ with torch.inference_mode():
+ wkv_b_a, wkv_b_a_scales = self.convert_to_general(weights)
+
+ n_heads = wkv_b_a.size(0)
+ head_dim = wkv_b_a.size(2)
+ kv_lora_rank = wkv_b_a.size(1)
+ num_ctas = 80
+ rows_per_cta = (n_heads * kv_lora_rank) // num_ctas
+
+ is_glm5 = self.model_args.arch_name == "glm_5"
+
+ w_flat = wkv_b_a.reshape(num_ctas, rows_per_cta // 16, 16, head_dim)
+ w_swizzled = self._swizzle_mma_16x16_for_pages(w_flat, head_dim, pages=1)
+ w_bytes = w_swizzled.reshape(num_ctas, -1)
+
+ kScalesPerPage = head_dim // 64
+
+ if is_glm5:
+ ctas_per_scale_row = 128 // rows_per_cta
+ scales_expanded = wkv_b_a_scales.repeat_interleave(ctas_per_scale_row, dim=1)
+ scales_per_cta = scales_expanded.reshape(num_ctas, kScalesPerPage)
+ scale_dtype = torch.float32
+ else:
+ scales_per_cta = wkv_b_a_scales.reshape(num_ctas, kScalesPerPage)
+ scale_dtype = torch.bfloat16
+
+ mat_bytes = rows_per_cta * head_dim
+ scale_elem_bytes = 4 if scale_dtype == torch.float32 else 2
+ scale_bytes = kScalesPerPage * scale_elem_bytes
+ page_size = (mat_bytes + scale_bytes + 127) // 128 * 128
+
+ scales_raw = scales_per_cta.to(scale_dtype).contiguous().view(torch.float8_e4m3fn)
+ padding_size = page_size - mat_bytes - scales_raw.shape[-1]
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_b_a.device
+ )
+ return torch.cat([w_bytes, scales_raw, padding], dim=-1).contiguous()
+
+ def convert_to_bf16mma(self, weights: list[torch.Tensor]) -> torch.Tensor:
+ """Convert weights to the packed format expected by the BF16 kernel."""
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local_heads = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local_heads % 2 != 0:
+ n_local_heads += 1
+
+ nope_head_dim = self.model_args.qk_nope_head_dim
+ kv_lora_rank = self.model_args.kv_lora_rank
+ hd_block = self.head_dim_block_size
+ n_block = self.model_args.block_size
+
+ s = tilert_wkv_b_scales.float()
+ s = s.repeat_interleave(hd_block, dim=1).repeat_interleave(n_block, dim=2)
+ wkv_bf16 = (
+ (tilert_wkv_b_weights.float() * s).transpose(1, 2).contiguous().to(torch.bfloat16)
+ )
+ n_heads = n_local_heads
+ head_dim = nope_head_dim
+
+ num_ctas = 80
+ rows_per_cta = (n_heads * kv_lora_rank) // num_ctas
+
+ w_flat = wkv_bf16.reshape(num_ctas, rows_per_cta // 16, 16, head_dim)
+ w_swizzled = self._swizzle_mma_16x16_for_pages(w_flat, head_dim, pages=1)
+ w_bytes = w_swizzled.reshape(num_ctas, -1).contiguous().view(torch.float8_e4m3fn)
+
+ mat_bytes = rows_per_cta * head_dim * 2
+ page_size = (mat_bytes + 127) // 128 * 128
+ padding_size = page_size - w_bytes.shape[-1]
+
+ if padding_size > 0:
+ padding = torch.zeros(
+ num_ctas, padding_size, dtype=torch.float8_e4m3fn, device=wkv_bf16.device
+ )
+ return torch.cat([w_bytes, padding], dim=-1).contiguous()
+ return w_bytes.contiguous()
+
+ def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ with torch.inference_mode():
+ tilert_wkv_b_weights, tilert_wkv_b_scales = weights
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ n_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local_heads = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local_heads % 2 != 0:
+ n_local_heads += 1
+
+ wkv_b = tilert_wkv_b_weights
+ wkv_b_scales_raw = tilert_wkv_b_scales
+ wkv_b = wkv_b.view(n_local_heads, -1, self.model_args.kv_lora_rank)
+ assert self.model_args.kv_lora_rank % self.model_args.block_size == 0
+ wkv_b_scales_raw = wkv_b_scales_raw.view(
+ n_local_heads, -1, self.model_args.kv_lora_rank // self.model_args.block_size
+ )
+ wkv_b_a = wkv_b[:, : self.model_args.qk_nope_head_dim].transpose(1, 2).contiguous()
+ assert self.model_args.qk_nope_head_dim % self.head_dim_block_size == 0
+ wkv_b_a_scales = (
+ wkv_b_scales_raw[:, : self.model_args.qk_nope_head_dim // self.head_dim_block_size]
+ .transpose(1, 2)
+ .contiguous()
+ )
+ if self.model_args.arch_name == "glm_5":
+ if wkv_b_a_scales.dtype != torch.float32:
+ print(
+ "Warning: ProjqWqbWeightsConverter: "
+ + f"wkv_b_a_scales.dtype: {wkv_b_a_scales.dtype} "
+ + "is not float32, convert to float32."
+ )
+ wkv_b_a_scales = wkv_b_a_scales.to(torch.float32)
+ else:
+ wkv_b_a_scales = wkv_b_a_scales.to(torch.bfloat16)
+ if self.head_dim_block_size != self.impl_block_size:
+ repeats = self.head_dim_block_size // self.impl_block_size
+ wkv_b_a_scales = wkv_b_a_scales.repeat(1, 1, repeats).contiguous()
+
+ wkv_b_a = wkv_b_a.detach()
+ wkv_b_a_scales = wkv_b_a_scales.detach()
+
+ return wkv_b_a, wkv_b_a_scales
+
+
+@dataclass
+class ProjqWqbRefWeightsAlias:
+ """Reference weights alias for ProjqWqb."""
+
+ wkv_b_weights = "self_attn.kv_b_proj.weight"
+ wkv_b_scales = "self_attn.kv_b_proj.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.wkv_b_weights, self.wkv_b_scales]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class ProjqWqbTilertWeightsAlias:
+ """TileRT weights alias for ProjqWqb."""
+
+ wkv_b_weights = "wkv_b1_weights"
+ wkv_b_scales = "wkv_b1_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.wkv_b_weights, self.wkv_b_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ProjqWqb(TileRTModule):
+ """ProjqWqb module: Q projection (wkv_b) for KV LoRA."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjqWqbAlgorithm.FP16MMA],
+ "glm_5": [ProjqWqbAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: ProjqWqbRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = ProjqWqbTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else ProjqWqbRefWeightsAlias()
+ )
+
+ self.ref_wkv_b: torch.Tensor | None = None
+ self.tilert_wkv_b_a: torch.Tensor | None = None
+ self.tilert_wkv_b_a_scales: torch.Tensor | None = None
+ self.output: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ self.compute_kernel_type = "fp16mma"
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ self.num_local_heads = self.model_args.n_heads // self.num_devices
+ else:
+ n_local = math.ceil(self.model_args.n_heads / self.num_devices)
+ if n_local % 2 != 0:
+ n_local += 1
+ self.num_local_heads = n_local
+
+ self.wkvb_lora_rank = self.model_args.kv_lora_rank
+ self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size
+
+ self.wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim
+ self.wkvb_nope_head_dim = self.model_args.qk_nope_head_dim
+ left_head_dim = self.wkvb_head_dim % self.model_args.block_size
+ if left_head_dim != 0:
+ assert self.model_args.block_size % left_head_dim == 0
+ self.head_dim_block_size = left_head_dim
+ self.head_dim_scale_repeat = self.model_args.block_size // self.head_dim_block_size
+ else:
+ self.head_dim_scale_repeat = 1
+ self.head_dim_block_size = self.model_args.block_size
+ self.wkvb_head_qsize = self.wkvb_head_dim // self.head_dim_block_size
+ self.wkvb_nope_head_qsize = self.wkvb_nope_head_dim // self.head_dim_block_size
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias.tilert_tensor_alias
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_wkv_b_a, self.tilert_wkv_b_a_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: split weights and scales per device.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor.
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights]
+ kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales]
+
+ if self.model_args.n_heads % self.num_devices == 0:
+ dev_weights = kv_b_proj_weight.view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = kv_b_proj_weight_scale.view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+ else:
+ from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projq_wqb import (
+ RmsnormProjqWqbWeightsConverter,
+ )
+
+ wq_b_list, scale_list = RmsnormProjqWqbWeightsConverter._redistribute_heads(
+ kv_b_proj_weight,
+ kv_b_proj_weight_scale,
+ n_total_heads=self.model_args.n_heads,
+ n_local_heads=self.num_local_heads,
+ num_devices=self.num_devices,
+ qk_head_dim=self.wkvb_head_dim,
+ block_size=self.model_args.block_size,
+ )
+ dev_weights = torch.stack(wq_b_list, dim=0).view(
+ self.num_devices, self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank
+ )
+ dev_scale_rows = self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size
+ dev_scales = torch.stack(scale_list, dim=0).view(
+ self.num_devices, dev_scale_rows, 1, self.wkvb_lora_rank_qsize
+ )
+
+ wkvb = dev_weights[:, :, : self.wkvb_nope_head_dim]
+ wkvb_scales = (
+ dev_scales.contiguous()
+ .repeat(1, 1, self.head_dim_scale_repeat, 1)
+ .view(
+ self.num_devices,
+ self.num_local_heads,
+ self.wkvb_head_qsize,
+ self.wkvb_lora_rank_qsize,
+ )
+ .contiguous()[:, :, : self.wkvb_nope_head_qsize]
+ )
+ return {
+ self.tilert_weights_alias.wkv_b_weights: wkvb.contiguous(),
+ self.tilert_weights_alias.wkv_b_scales: wkvb_scales.contiguous(),
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ sharding_size = self.num_local_heads * self.wkvb_head_dim
+ sharding_start = self.device_id * sharding_size
+ sharding_end = sharding_start + sharding_size
+ wkv_b = weight_dequant(
+ state_dict[self.ref_weights_alias.wkv_b_weights],
+ state_dict[self.ref_weights_alias.wkv_b_scales],
+ )
+ wkv_b = wkv_b[sharding_start:sharding_end, :]
+ wkv_b = wkv_b.view(self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank)
+ self.ref_wkv_b = wkv_b[:, : self.wkvb_nope_head_dim]
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.init_tilert_weights_hmma(state_dict)
+
+ def init_tilert_weights_hmma(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with HMMA-packed weights."""
+ packed = ProjqWqbWeightsConverter(
+ self.model_args, self.num_devices, self.head_dim_block_size
+ ).dispatch(
+ ProjqWqbAlgorithm.FP16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_a = packed
+ self.tilert_wkv_b_a_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "fp16mma"
+
+ def init_tilert_weights_hmma_bf16(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize with BF16 HMMA-packed weights (dequantized, no scales)."""
+ packed = ProjqWqbWeightsConverter(
+ self.model_args, self.num_devices, self.head_dim_block_size
+ ).dispatch(
+ ProjqWqbAlgorithm.BF16MMA,
+ [
+ state_dict[self.tilert_weights_alias.wkv_b_weights],
+ state_dict[self.tilert_weights_alias.wkv_b_scales],
+ ],
+ )
+ self.tilert_wkv_b_a = packed
+ self.tilert_wkv_b_a_scales = torch.empty(1, dtype=torch.float8_e4m3fn, device=packed.device)
+ self.compute_kernel_type = "bf16mma"
+
+ def init_random_weights(self) -> None:
+ padded_total_heads = self.num_local_heads * self.num_devices
+ wkv_b = init_func(
+ torch.empty(
+ padded_total_heads * self.wkvb_head_dim,
+ self.wkvb_lora_rank,
+ dtype=torch.float8_e4m3fn,
+ )
+ )
+ wkv_b_scales = init_func(
+ torch.empty(
+ padded_total_heads * self.wkvb_head_dim // self.model_args.block_size,
+ self.wkvb_lora_rank_qsize,
+ dtype=torch.float32,
+ )
+ )
+ ref_state_dict = dict(zip(self.ref_weights_alias(), [wkv_b, wkv_b_scales]))
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()})
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.output = torch.zeros(
+ (batch_size, seq_len, self.num_local_heads, self.wkvb_lora_rank), dtype=torch.bfloat16
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, q_nope: torch.Tensor) -> torch.Tensor:
+ assert self.ref_wkv_b is not None
+ return torch.einsum("bshd,hdc->bshc", q_nope, self.ref_wkv_b)
+
+ def tilert_forward(self, q_nope: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_wkv_b_a is not None
+ assert self.tilert_wkv_b_a_scales is not None
+ assert self.output is not None
+ assert self.profile_logs is not None
+ projq_wqb(
+ q_nope,
+ self.tilert_wkv_b_a,
+ self.tilert_wkv_b_a_scales,
+ self.output,
+ self.profile_logs,
+ self.compute_kernel_type,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.output
diff --git a/tilert/models/glm_5/_dsa_v32/ops/projx_wis.py b/tilert/models/glm_5/_dsa_v32/ops/projx_wis.py
new file mode 100644
index 0000000..e13b4e0
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/projx_wis.py
@@ -0,0 +1,211 @@
+"""ProjxWis operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.common import init_func
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "projx_wis",
+ "ProjxWis",
+ "ProjxWisRefWeightsAlias",
+ "ProjxWisTilertWeightsAlias",
+]
+
+
+def projx_wis(
+ x_in: torch.Tensor,
+ w: torch.Tensor,
+ output: torch.Tensor,
+ compute_kernel_type: str,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+) -> None:
+ """
+ Define the ProjxWis operation.
+
+ Args:
+ x_in: Input tensor.
+ w: Weight tensor.
+ output: Output tensor.
+ compute_kernel_type: Compute kernel type ("bf16" or "bf16mma").
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ """
+ torch.ops.tilert.proj_w_op(x_in, w, output, model_arch, compute_kernel_type, profile_logs)
+
+
+@dataclass
+class ProjxWisRefWeightsAlias:
+ """Reference weights alias for ProjxWis."""
+
+ w_weights = "self_attn.indexer.weights_proj.weight"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.w_weights]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class ProjxWisTilertWeightsAlias:
+ """TileRT weights alias for ProjxWis."""
+
+ w_weights = "id_score_weights"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.w_weights]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ProjxWisAlgorithm(Enum):
+ """ProjxWis algorithm."""
+
+ BF16 = "bf16"
+ BF16MMA = "bf16mma"
+
+
+class ProjxWis(TileRTModule):
+ """ProjxWis module: linear projection for indexer score weights."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjxWisAlgorithm.BF16, ProjxWisAlgorithm.BF16MMA],
+ "glm_5": [ProjxWisAlgorithm.BF16, ProjxWisAlgorithm.BF16MMA],
+ }
+
+ _HMMA_CONFIGS = {
+ 7168: (4, 16, 7),
+ 6144: (2, 16, 6),
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: ProjxWisRefWeightsAlias | None = None,
+ compute_kernel_type: str | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = ProjxWisTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else ProjxWisRefWeightsAlias()
+ )
+
+ self.ref_tensor_alias = self.ref_weights_alias.ref_tensor_alias
+
+ self.ref_w: torch.Tensor | None = None
+ self.tilert_w: torch.Tensor | None = None
+ self.output: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ self.dim = model_args.dim
+ self.index_n_heads = model_args.index_n_heads
+
+ if compute_kernel_type is not None:
+ self.compute_kernel_type = compute_kernel_type
+ else:
+ self.compute_kernel_type = "bf16"
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle each 16x16 BF16 tile for the packed weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _to_hmma_layout(
+ w_orig: torch.Tensor, n_ctas: int, rows_per_cta: int, x_dim: int, num_pages: int
+ ) -> torch.Tensor:
+ """Convert [output_dim, x_dim] BF16 weights to the packed kernel layout."""
+ cols_per_page = x_dim // num_pages
+ n_k_tiles = cols_per_page // 16
+ w = w_orig.reshape(n_ctas, rows_per_cta, num_pages, cols_per_page)
+ w = w.transpose(1, 2)
+ n_row_tiles = rows_per_cta // 16
+ w = w.reshape(n_ctas, num_pages, n_row_tiles, 16, n_k_tiles, 16).transpose(-3, -2)
+ w = ProjxWis._swizzle_mma_16x16(w)
+ return w.reshape(n_ctas, -1).contiguous()
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias.tilert_tensor_alias
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_w]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: replicate weight for each device.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor.
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ w = weights_map[self.ref_weights_alias.w_weights]
+ if self.compute_kernel_type == "bf16mma":
+ n_ctas, rows_per_cta, num_pages = self._HMMA_CONFIGS[self.dim]
+ w_hmma = self._to_hmma_layout(w, n_ctas, rows_per_cta, self.dim, num_pages)
+ w_out = w_hmma[None, ...].repeat(self.num_devices, 1, 1)
+ else:
+ w_out = w[None, ...].repeat(self.num_devices, 1, 1)
+ return {self.tilert_weights_alias.w_weights: w_out}
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ w = state_dict[self.ref_weights_alias.w_weights]
+ self.ref_w = w.detach().clone().to(torch.bfloat16)
+ self.is_ref_weights_init = True
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.tilert_w = state_dict[self.tilert_weights_alias.w_weights].detach().clone()
+ self.is_tilert_weights_init = True
+
+ def init_random_weights(self) -> None:
+ ref_w = init_func(torch.empty(self.index_n_heads, self.dim, dtype=torch.bfloat16))
+ ref_state_dict = dict(zip(self.ref_weights_alias(), [ref_w]))
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()})
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.output = torch.zeros((batch_size, seq_len, self.index_n_heads), dtype=torch.bfloat16)
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, x_norm: torch.Tensor) -> torch.Tensor:
+ assert self.ref_w is not None
+ return torch.nn.functional.linear(x_norm, self.ref_w)
+
+ def tilert_forward(self, x_norm: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_w is not None
+ assert self.output is not None
+ assert self.profile_logs is not None
+ projx_wis(
+ x_norm,
+ self.tilert_w,
+ self.output,
+ self.compute_kernel_type,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ return self.output
diff --git a/tilert/models/glm_5/_dsa_v32/ops/projx_wqaki.py b/tilert/models/glm_5/_dsa_v32/ops/projx_wqaki.py
new file mode 100644
index 0000000..367d5fe
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/projx_wqaki.py
@@ -0,0 +1,247 @@
+"""ProjxWqaki operation module."""
+
+import torch
+
+__all__ = [
+ "projx_wqaki",
+ "ProjxWqakiWeightsConverter",
+]
+
+
+def projx_wqaki(
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ wqaki: torch.Tensor,
+ out_q: torch.Tensor,
+ out_ki: torch.Tensor,
+ profile_logs: torch.Tensor,
+ compute_kernel_type: str = "fp8mma",
+ *,
+ model_arch: str,
+) -> None:
+ """FP8 projection for q, ki.
+
+ Args:
+ x_quant: FP8 quantized hidden states [1, seq_len, hidden_dim].
+ x_scale: Scale factors for x_quant.
+ wqaki: Packed FP8 weights + scales for q, ki.
+ out_q: Output q tensor.
+ out_ki: Output ki tensor.
+ profile_logs: Profile logs tensor.
+ compute_kernel_type: Kernel type ("fp8mma", "fp8mma_68cta", "fp8mma_136cta").
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ """
+ torch.ops.tilert.projx_wqaki_op(
+ x_quant,
+ x_scale,
+ wqaki,
+ out_q,
+ out_ki,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=x_quant.device),
+ )
+
+
+class ProjxWqakiWeightsConverter:
+ """Weight converter for ProjxWqaki kernel."""
+
+ @staticmethod
+ def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ assert mat_in.dtype == torch.float8_e4m3fn
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def convert_dsv32(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wki: torch.Tensor,
+ wki_scale: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert DSV3.2 weights to the packed format expected by the kernel."""
+ with torch.inference_mode():
+ wq_a_scale = wq_a_scale.to(torch.bfloat16)
+ wki_scale = wki_scale.to(torch.bfloat16)
+
+ dim = 7168
+ q_rows = 1536
+ ki_rows = 128
+ total_rows = q_rows + ki_rows
+ n_blocks = total_rows // 16
+ scale_dim = dim // 128
+
+ n_q_blocks = q_rows // 16
+ n_ki_blocks = ki_rows // 16
+ wq_a = wq_a.reshape(n_q_blocks, 16, dim)
+ wq_a_scale = (
+ wq_a_scale.reshape(wq_a_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_q_blocks // wq_a_scale.shape[0], 1)
+ .reshape(n_q_blocks, scale_dim)
+ )
+ wki = wki.reshape(n_ki_blocks, 16, dim)
+ wki_scale = (
+ wki_scale.reshape(wki_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_ki_blocks // wki_scale.shape[0], 1)
+ .reshape(n_ki_blocks, scale_dim)
+ )
+
+ wqaki = torch.cat([wq_a, wki], dim=0)
+ wqaki_scale = torch.cat([wq_a_scale, wki_scale], dim=0)
+
+ swizzle = ProjxWqakiWeightsConverter._swizzle_qmma_16x32
+
+ wqaki_0 = wqaki[..., :2048]
+ wqaki_0_scale = wqaki_scale[..., :16].contiguous().view(torch.float8_e4m3fn)
+ wqaki_1 = wqaki[..., 2048:4096]
+ wqaki_1_scale = wqaki_scale[..., 16:32].contiguous().view(torch.float8_e4m3fn)
+ wqaki_2 = wqaki[..., 4096:6144]
+ wqaki_2_scale = wqaki_scale[..., 32:48].contiguous().view(torch.float8_e4m3fn)
+ wqaki_3 = wqaki[..., 6144:7168]
+ wqaki_3_scale = wqaki_scale[..., 48:56].contiguous().view(torch.float8_e4m3fn)
+
+ wqaki_0 = wqaki_0.reshape(n_blocks, 16, 64, 32).transpose(1, 2)
+ wqaki_0 = swizzle(wqaki_0).reshape(n_blocks, 16 * 2048)
+
+ wqaki_1 = wqaki_1.reshape(n_blocks, 16, 64, 32).transpose(1, 2)
+ wqaki_1 = swizzle(wqaki_1).reshape(n_blocks, 16 * 2048)
+
+ wqaki_2 = wqaki_2.reshape(n_blocks, 16, 64, 32).transpose(1, 2)
+ wqaki_2 = swizzle(wqaki_2).reshape(n_blocks, 16 * 2048)
+
+ wqaki_3 = wqaki_3.reshape(n_blocks, 16, 32, 32).transpose(1, 2)
+ wqaki_3 = swizzle(wqaki_3).reshape(n_blocks, 16 * 1024)
+
+ padding_scale0 = torch.zeros(
+ (n_blocks, 48), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+ padding_scale1 = torch.zeros(
+ (n_blocks, 48), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+ padding_scale2 = torch.zeros(
+ (n_blocks, 48), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+ padding_scale3 = torch.zeros(
+ (n_blocks, 56), dtype=torch.bfloat16, device=wq_a.device
+ ).view(torch.float8_e4m3fn)
+
+ return torch.cat(
+ [
+ wqaki_0,
+ wqaki_0_scale,
+ padding_scale0,
+ wqaki_1,
+ wqaki_1_scale,
+ padding_scale1,
+ wqaki_2,
+ wqaki_2_scale,
+ padding_scale2,
+ wqaki_3,
+ wqaki_3_scale,
+ padding_scale3,
+ ],
+ dim=1,
+ ).contiguous()
+
+ @staticmethod
+ def convert_glm5_68cta(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wki: torch.Tensor,
+ wki_scale: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert GLM5 weights to the packed format expected by the kernel."""
+ with torch.inference_mode():
+ wq_a_scale = wq_a_scale.to(torch.float32)
+ wki_scale = wki_scale.to(torch.float32)
+
+ dim = 6144
+ q_rows = 2048
+ ki_rows = 128
+ total_rows = q_rows + ki_rows
+ n_blocks = total_rows // 32
+ scale_dim = dim // 128
+
+ n_q_blocks = q_rows // 32
+ n_ki_blocks = ki_rows // 32
+
+ wqaki_raw = torch.cat([wq_a, wki], dim=0).reshape(n_blocks, 32, dim)
+
+ wq_a_scale = (
+ wq_a_scale.reshape(wq_a_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_q_blocks // wq_a_scale.shape[0], 1)
+ .reshape(n_q_blocks, scale_dim)
+ )
+ wki_scale = (
+ wki_scale.reshape(wki_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_ki_blocks // wki_scale.shape[0], 1)
+ .reshape(n_ki_blocks, scale_dim)
+ )
+ wqaki_scales = torch.cat([wq_a_scale, wki_scale], dim=0)
+
+ swizzle = ProjxWqakiWeightsConverter._swizzle_qmma_16x32
+
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 32, 6, 1024).transpose(1, 2)
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 6, 2, 16, 32, 32).transpose(3, 4)
+ wqaki_raw = swizzle(wqaki_raw).reshape(n_blocks, 6, 32 * 1024)
+ wqaki_scales = wqaki_scales.reshape(n_blocks, 6, 8).view(torch.float8_e4m3fn)
+ wqaki_padding = torch.zeros(
+ (n_blocks, 6, 128 - wqaki_scales.shape[-1]),
+ dtype=torch.float8_e4m3fn,
+ device=wq_a.device,
+ )
+ return torch.cat([wqaki_raw, wqaki_scales, wqaki_padding], dim=-1).contiguous()
+
+ @staticmethod
+ def convert_glm5_136cta(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wki: torch.Tensor,
+ wki_scale: torch.Tensor,
+ ) -> torch.Tensor:
+ """Convert GLM5 weights to the packed format expected by the kernel."""
+ with torch.inference_mode():
+ wq_a_scale = wq_a_scale.to(torch.float32)
+ wki_scale = wki_scale.to(torch.float32)
+
+ dim = 6144
+ q_rows = 2048
+ ki_rows = 128
+ total_rows = q_rows + ki_rows
+ n_blocks = total_rows // 16
+ scale_dim = dim // 128
+
+ n_q_blocks = q_rows // 16
+ n_ki_blocks = ki_rows // 16
+
+ wq_a = wq_a.reshape(n_q_blocks, 16, dim)
+ wq_a_scale = (
+ wq_a_scale.reshape(wq_a_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_q_blocks // wq_a_scale.shape[0], 1)
+ .reshape(n_q_blocks, scale_dim)
+ )
+ wki = wki.reshape(n_ki_blocks, 16, dim)
+ wki_scale = (
+ wki_scale.reshape(wki_scale.shape[0], 1, scale_dim)
+ .repeat(1, n_ki_blocks // wki_scale.shape[0], 1)
+ .reshape(n_ki_blocks, scale_dim)
+ )
+
+ wqaki_raw = torch.cat([wq_a, wki], dim=0)
+ wqaki_scales = torch.cat([wq_a_scale, wki_scale], dim=0)
+
+ swizzle = ProjxWqakiWeightsConverter._swizzle_qmma_16x32
+
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 16, 3, 2048).transpose(1, 2)
+ wqaki_raw = wqaki_raw.reshape(n_blocks, 3, 1, 16, 64, 32).transpose(3, 4)
+ wqaki_raw = swizzle(wqaki_raw).reshape(n_blocks, 3, 16 * 2048)
+ wqaki_scales = wqaki_scales.reshape(n_blocks, 3, 16).view(torch.float8_e4m3fn)
+ wqaki_padding = torch.zeros(
+ (n_blocks, 3, 128 - wqaki_scales.shape[-1]),
+ dtype=torch.float8_e4m3fn,
+ device=wq_a.device,
+ )
+ return torch.cat([wqaki_raw, wqaki_scales, wqaki_padding], dim=-1).contiguous()
diff --git a/tilert/models/glm_5/_dsa_v32/ops/projx_wqkva.py b/tilert/models/glm_5/_dsa_v32/ops/projx_wqkva.py
new file mode 100644
index 0000000..6ade7af
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/projx_wqkva.py
@@ -0,0 +1,330 @@
+"""ProjXWqkva operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_projx_wqkva import (
+ RMSNormProjQKVAFP8MMAWeightsConverter,
+ RMSNormProjQKVAFP16MMAWeightsConverter,
+)
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "ProjXWqkva",
+ "projx_wqkva",
+]
+
+
+def projx_wqkva(
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ wqkva: torch.Tensor,
+ cur_pos: torch.Tensor,
+ q_out: torch.Tensor,
+ kv_out: torch.Tensor,
+ pe_cache_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ compute_kernel_type: str = "fp8mma",
+ *,
+ model_arch: str,
+) -> None:
+ """FP8 MMA projection for q, kv, pe_cache (DSV3.2)."""
+ torch.ops.tilert.projx_wqkva_op(
+ x_quant,
+ x_scale,
+ wqkva,
+ cur_pos,
+ q_out,
+ kv_out,
+ pe_cache_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=x_quant.device),
+ )
+
+
+class ProjXWqkvaRefWeightsAlias:
+ """Reference weight aliases for ProjXWqkva."""
+
+ x_rmsnorm_gamma = "input_layernorm.weight"
+ q_a_weights = "self_attn.q_a_proj.weight"
+ q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
+ kv_a_weights = "self_attn.kv_a_proj_with_mqa.weight"
+ kv_a_scales = "self_attn.kv_a_proj_with_mqa.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+class ProjXWqkvaTilertWeightsAlias:
+ """Tilert weight aliases for ProjXWqkva."""
+
+ q_a_weights = "q_a_weights"
+ q_a_scales = "q_a_scales"
+ kv_a_weights = "kv_a_weights"
+ kv_a_scales = "kv_a_scales"
+ w_pe_weights = "w_pe_weights"
+ w_pe_scales = "w_pe_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ self.w_pe_weights,
+ self.w_pe_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class ProjXWqkvaAlgorithm(Enum):
+ """ProjXWqkva algorithm."""
+
+ FP8MMA = "fp8mma"
+ FP16MMA = "fp16mma"
+
+
+class ProjXWqkva(TileRTModule):
+ """FP8 MMA projection module for q, kv, pe_cache."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [ProjXWqkvaAlgorithm.FP8MMA],
+ "glm_5": [ProjXWqkvaAlgorithm.FP8MMA, ProjXWqkvaAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: ProjXWqkvaRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = ProjXWqkvaTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else ProjXWqkvaRefWeightsAlias()
+ )
+
+ self.dim = self.model_args.dim
+ self.q_lora_rank = self.model_args.q_lora_rank
+ self.kv_lora_rank = self.model_args.kv_lora_rank
+ self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ self.block_size = self.model_args.block_size
+ self.eps = self.model_args.eps
+
+ self.ref_wq_a: torch.Tensor | None = None
+ self.ref_wkv_a: torch.Tensor | None = None
+ self.ref_w_pe: torch.Tensor | None = None
+
+ self.tilert_wqkva: torch.Tensor | None = None
+
+ self.q_out: torch.Tensor | None = None
+ self.kv_out: torch.Tensor | None = None
+ self.pe_cache_out: torch.Tensor | None = None
+ self.cur_pos: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.compute_kernel_type = "fp8mma"
+
+ def set_algorithm(self, algorithm: Enum) -> None:
+ super().set_algorithm(algorithm)
+ if algorithm == ProjXWqkvaAlgorithm.FP16MMA:
+ self.compute_kernel_type = "fp16mma"
+ else:
+ self.compute_kernel_type = "fp8mma"
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat weights for device sharding."""
+ q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ kv_a_mqa = weights_map[self.ref_weights_alias.kv_a_weights]
+ kv_a_proj_weight = kv_a_mqa[: self.kv_lora_rank, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight = kv_a_mqa[self.kv_lora_rank :, :][None, ...].repeat(self.num_devices, 1, 1)
+ kv_a_mqa_scale = weights_map[self.ref_weights_alias.kv_a_scales]
+ kv_scale_rows = (self.kv_lora_rank + self.block_size - 1) // self.block_size
+ kv_a_proj_weight_scale = kv_a_mqa_scale[:kv_scale_rows, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight_scale = kv_a_mqa_scale[kv_scale_rows:, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ return {
+ self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
+ self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
+ self.tilert_weights_alias.kv_a_weights: kv_a_proj_weight,
+ self.tilert_weights_alias.kv_a_scales: kv_a_proj_weight_scale,
+ self.tilert_weights_alias.w_pe_weights: w_pe_weight,
+ self.tilert_weights_alias.w_pe_scales: w_pe_weight_scale,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ aliases = self.ref_weights_alias()
+ self.ref_wq_a = weight_dequant(state_dict[aliases[1]], state_dict[aliases[2]])
+ kv_a_mqa = weight_dequant(state_dict[aliases[3]], state_dict[aliases[4]])
+ self.ref_wkv_a = kv_a_mqa[: self.kv_lora_rank, :]
+ self.ref_w_pe = kv_a_mqa[self.kv_lora_rank :, :]
+
+ assert self.ref_wq_a.shape == (self.q_lora_rank, self.dim)
+ assert self.ref_wkv_a.shape == (self.kv_lora_rank, self.dim)
+ assert self.ref_w_pe.shape == (self.qk_rope_head_dim, self.dim)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ tilert_aliases = self.tilert_weights_alias()
+ wq_a = state_dict[tilert_aliases[0]]
+ wq_a_scale = state_dict[tilert_aliases[1]]
+ wkv_a = state_dict[tilert_aliases[2]]
+ wkv_a_scale = state_dict[tilert_aliases[3]]
+ w_pe = state_dict[tilert_aliases[4]]
+ w_pe_scale = state_dict[tilert_aliases[5]]
+ dummy_gamma = torch.zeros(self.dim, dtype=torch.float32, device=wq_a.device)
+
+ if self.algorithm == ProjXWqkvaAlgorithm.FP16MMA:
+ self.tilert_wqkva, _ = RMSNormProjQKVAFP16MMAWeightsConverter.convert_to_fp16_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ dummy_gamma,
+ hidden_dim=self.dim,
+ q_lora_rank=self.q_lora_rank,
+ )
+ else:
+ self.tilert_wqkva, _ = RMSNormProjQKVAFP8MMAWeightsConverter.convert_to_fp8_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ dummy_gamma,
+ hidden_dim=self.dim,
+ q_lora_rank=self.q_lora_rank,
+ )
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, max_len: int = 128) -> None:
+ self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
+ self.kv_out = torch.zeros((batch_size, seq_len, self.kv_lora_rank), dtype=torch.bfloat16)
+ self.pe_cache_out = torch.zeros(
+ (batch_size, max_len, self.qk_rope_head_dim), dtype=torch.bfloat16
+ )
+ self.cur_pos = torch.zeros((1,), dtype=torch.int32)
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def init_random_weights(self) -> None:
+ bs = self.block_size
+ dim_scale_dim = self.dim // bs
+ q_scale_dim = (self.q_lora_rank + bs - 1) // bs
+ kv_mqa_rows = self.kv_lora_rank + self.qk_rope_head_dim
+ kv_mqa_scale_dim = (kv_mqa_rows + bs - 1) // bs
+ scale_dtype = torch.bfloat16
+
+ tensor_list = [
+ torch.randn(self.dim, dtype=torch.float32),
+ torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(kv_mqa_rows, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(kv_mqa_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def golden_forward(
+ self,
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ cur_pos: int = 0, # noqa: U100
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Pure PyTorch reference: dequant FP8 -> matmul -> q, kv, pe."""
+ assert self.ref_wq_a is not None
+ assert self.ref_wkv_a is not None
+ assert self.ref_w_pe is not None
+
+ if self.algorithm == ProjXWqkvaAlgorithm.FP16MMA:
+ x_float = x_quant.float()
+ else:
+ x_fp8 = x_quant.to(torch.float32)
+ scale_expanded = x_scale.unsqueeze(-1).repeat(1, 1, 1, self.block_size)
+ scale_expanded = scale_expanded.reshape(x_quant.shape)
+ x_float = x_fp8 * scale_expanded
+
+ q_out = torch.matmul(x_float, self.ref_wq_a.transpose(0, 1).float())
+ kv_out = torch.matmul(x_float, self.ref_wkv_a.transpose(0, 1).float())
+ pe_out = torch.matmul(x_float, self.ref_w_pe.transpose(0, 1).float())
+ return (
+ q_out.to(torch.bfloat16),
+ kv_out.to(torch.bfloat16),
+ pe_out.to(torch.bfloat16),
+ )
+
+ def tilert_forward(
+ self,
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run FP8 QMMA GEMV via TileRT CUDA kernel."""
+ assert self.cur_pos is not None
+ assert self.pe_cache_out is not None
+ self.cur_pos.fill_(cur_pos)
+ projx_wqkva(
+ x_quant,
+ x_scale,
+ self.tilert_wqkva,
+ self.cur_pos,
+ self.q_out,
+ self.kv_out,
+ self.pe_cache_out,
+ self.profile_logs,
+ self.compute_kernel_type,
+ model_arch=self.model_args.arch_name,
+ )
+
+ seq_len = x_quant.size(-2)
+ pe_at_pos = self.pe_cache_out[:, cur_pos : cur_pos + seq_len, :]
+ return self.q_out, self.kv_out, pe_at_pos
+
+ def __call__(
+ self,
+ x_quant: torch.Tensor,
+ x_scale: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.golden_forward(x_quant, x_scale, cur_pos)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/qkv_rope.py b/tilert/models/glm_5/_dsa_v32/ops/qkv_rope.py
new file mode 100644
index 0000000..7f16a1c
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/qkv_rope.py
@@ -0,0 +1,192 @@
+"""QKV Rope operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.utils import apply_rotary_emb
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "qkv_rope",
+ "QKVRoPE",
+ "QKVRoPERefWeightsAlias",
+ "QKVRoPETilertWeightsAlias",
+]
+
+
+def qkv_rope(
+ pe_cache: torch.Tensor,
+ kv_cache: torch.Tensor,
+ rope_freqs: torch.Tensor,
+ cur_pos: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+) -> None:
+ """
+ Perform QKV Rope operation.
+
+ Args:
+ pe_cache: Q PE tensor (bsz, seq, n_local_heads, qk_rope_head_dim).
+ kv_cache: K PE cache (bsz, seq, qk_rope_head_dim).
+ rope_freqs: Rope frequencies tensor.
+ cur_pos: Current position tensor.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture string.
+ compute_kernel_type: Compute kernel type string.
+ """
+ torch.ops.tilert.qkv_rope_op(
+ pe_cache,
+ kv_cache,
+ rope_freqs,
+ cur_pos,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+
+
+@dataclass
+class QKVRoPERefWeightsAlias:
+ """Reference weights alias for QKVRoPE (no weights)."""
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return []
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class QKVRoPETilertWeightsAlias:
+ """TileRT weights alias for QKVRoPE (no weights)."""
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return []
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class QKVRoPEAlgorithm(Enum):
+ """QKVRoPE algorithm."""
+
+ GENERAL = "general"
+
+
+class QKVRoPE(TileRTModule):
+ """QKV RoPE module. Unified for deepseek_v3_2 and glm_5."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [QKVRoPEAlgorithm.GENERAL],
+ "glm_5": [QKVRoPEAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int = 1,
+ device_id: int = 0,
+ layer_idx: int = 0,
+ ref_weights_alias: QKVRoPERefWeightsAlias | None = None,
+ ) -> None:
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ layer_idx=layer_idx,
+ )
+ self.tilert_weights_alias = QKVRoPETilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else QKVRoPERefWeightsAlias()
+ )
+ self.n_local_heads = model_args.n_heads // num_devices
+ self.qk_rope_head_dim = model_args.qk_rope_head_dim
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return []
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ del weights_map
+ return {}
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ del state_dict
+ pass
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ del state_dict
+ pass
+
+ def init_random_weights(self) -> None:
+ pass
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ del batch_size, seq_len
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(
+ self,
+ q_pe: torch.Tensor,
+ pe_cache: torch.Tensor,
+ start_pos: int,
+ freqs_cis: torch.Tensor,
+ bsz: int,
+ seqlen: int,
+ ) -> torch.Tensor:
+ end_pos = start_pos + seqlen
+
+ k_pe = pe_cache[:bsz, start_pos:end_pos]
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+ pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
+
+ return apply_rotary_emb(q_pe, freqs_cis)
+
+ def tilert_forward(
+ self,
+ q_pe: torch.Tensor,
+ pe_cache: torch.Tensor,
+ start_pos: int,
+ freqs_cis: torch.Tensor,
+ bsz: int,
+ seqlen: int,
+ ) -> torch.Tensor:
+ assert self.profile_logs is not None
+ end_pos = start_pos + seqlen
+
+ q_pe_rope = q_pe.clone()
+ rope_freqs = torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1)
+ cur_pos = torch.tensor([start_pos], dtype=torch.int32)
+
+ qkv_rope(
+ q_pe_rope,
+ pe_cache[:bsz, start_pos:end_pos],
+ rope_freqs,
+ cur_pos,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return q_pe_rope
+
+ def __call__(
+ self,
+ q_pe: torch.Tensor,
+ pe_cache: torch.Tensor,
+ start_pos: int,
+ freqs_cis: torch.Tensor,
+ bsz: int,
+ seqlen: int,
+ ) -> torch.Tensor:
+ if self.flag_enable_tilert:
+ return self.tilert_forward(q_pe, pe_cache, start_pos, freqs_cis, bsz, seqlen)
+ return self.golden_forward(q_pe, pe_cache, start_pos, freqs_cis, bsz, seqlen)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/receive_selected_token_ids.py b/tilert/models/glm_5/_dsa_v32/ops/receive_selected_token_ids.py
new file mode 100644
index 0000000..508d13e
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/receive_selected_token_ids.py
@@ -0,0 +1,35 @@
+"""ReceiveSelectedTokenIds — receive idx_selects from GPU 0."""
+
+import torch
+
+__all__ = [
+ "receive_selected_token_ids",
+]
+
+
+def receive_selected_token_ids(
+ ll_buf: torch.Tensor,
+ dst: torch.Tensor,
+ expected_flag: int,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
+) -> None:
+ """Receive idx_selects from GPU 0.
+
+ Args:
+ ll_buf: Receive buffer on this GPU (written by GPU 0).
+ dst: Destination idx_selects tensor [1, S, 2048] int32.
+ expected_flag: Expected synchronization flag value.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16").
+ """
+ torch.ops.tilert.receive_selected_token_ids_op(
+ ll_buf,
+ dst,
+ expected_flag,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_expert_proj.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_expert_proj.py
new file mode 100644
index 0000000..ae004fa
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_expert_proj.py
@@ -0,0 +1,172 @@
+"""RMSNormExpertProj operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+from torch import nn
+
+from tilert.models.base import TileRTModule
+from tilert.models.common import RMSNorm, init_func, linear
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RMSNormExpertProj",
+ "RMSNormExpertProjRefWeightsAlias",
+ "RMSNormExpertProjTilertWeightsAlias",
+]
+
+
+@dataclass
+class RMSNormExpertProjRefWeightsAlias:
+ """Reference weights alias for RMSNormExpertProj."""
+
+ post_attention_layernorm_weight = "post_attention_layernorm.weight"
+ mlp_gate_weight = "mlp.gate.weight"
+
+ def __call__(self) -> list[str]:
+ return [self.post_attention_layernorm_weight, self.mlp_gate_weight]
+
+
+@dataclass
+class RMSNormExpertProjTilertWeightsAlias:
+ """TileRT weights alias for RMSNormExpertProj."""
+
+ unproj_o_gamma = "unproj_o_gamma"
+ exp_proj_weights = "exp_proj_weights"
+
+ def __call__(self) -> list[str]:
+ return [self.unproj_o_gamma, self.exp_proj_weights]
+
+
+class RMSNormExpertProjAlgorithm(Enum):
+ """RMSNormExpertProj algorithm."""
+
+ GENERAL = "general"
+
+
+class RMSNormExpertProj(TileRTModule):
+ """RMS Norm followed by expert projection."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormExpertProjAlgorithm.GENERAL],
+ "glm_5": [RMSNormExpertProjAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int = 0,
+ ref_weights_alias: RMSNormExpertProjRefWeightsAlias | None = None,
+ tilert_weights_alias: RMSNormExpertProjTilertWeightsAlias | None = None,
+ ):
+ super().__init__(
+ type(self).__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+ self.dim = model_args.dim
+ self.eps = model_args.eps
+
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else RMSNormExpertProjRefWeightsAlias()
+ )
+ self.tilert_weights_alias = (
+ tilert_weights_alias
+ if tilert_weights_alias is not None
+ else RMSNormExpertProjTilertWeightsAlias()
+ )
+
+ self.is_ref_weights_init = False
+ self.is_tilert_weights_init = False
+
+ self.ref_rmsnorm: RMSNorm | None = None
+ self.ref_proj_weight: torch.Tensor | None = None
+ self.proj_weight = nn.Parameter(
+ init_func(torch.empty(model_args.n_routed_experts, model_args.dim))
+ )
+ self.n_routed_experts = model_args.n_routed_experts
+
+ self.tilert_proj_weight: torch.Tensor | None = None
+ self.tilert_rms_norm_weight: torch.Tensor | None = None
+
+ self.profile_logs = get_profile_log_tensor()
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_rms_norm_weight, self.tilert_proj_weight]
+
+ def device_sharding(
+ self, rms_norm_weight: torch.Tensor, proj_weight: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ return rms_norm_weight.float().contiguous(), proj_weight.contiguous()
+
+ def init_reference_weights(
+ self, state_dict: dict[str, torch.Tensor], device_id: int | None = None
+ ) -> None:
+ del device_id
+ self.ref_rmsnorm = RMSNorm(self.dim, self.eps)
+ self.ref_rmsnorm.weight.data = state_dict[
+ self.ref_weights_alias.post_attention_layernorm_weight
+ ]
+ self.ref_proj_weight = state_dict[self.ref_weights_alias.mlp_gate_weight]
+ self.is_ref_weights_init = True
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ self.tilert_proj_weight = (
+ state_dict[self.tilert_weights_alias.exp_proj_weights].detach().clone()
+ )
+ self.tilert_rms_norm_weight = (
+ state_dict[self.tilert_weights_alias.unproj_o_gamma].detach().clone()
+ )
+ self.is_tilert_weights_init = True
+
+ def init_random_weights(self) -> None:
+ proj_weight = torch.randn(self.n_routed_experts, self.dim)
+ rms_norm_weight = torch.randn(self.dim, dtype=torch.float32)
+ ref_state_dict = dict(
+ zip(
+ self.ref_weights_alias(),
+ [rms_norm_weight, proj_weight],
+ )
+ )
+ self.init_reference_weights(ref_state_dict)
+ assert self.ref_rmsnorm is not None and self.ref_proj_weight is not None
+ sharded_weights = self.device_sharding(self.ref_rmsnorm.weight, self.ref_proj_weight)
+ self.init_tilert_weights(dict(zip(self.tilert_weights_alias(), sharded_weights)))
+
+ def golden_forward(
+ self, x_in: torch.Tensor, residual: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ assert self.is_ref_weights_init, "Reference weights must be initialized before forward pass"
+ assert self.ref_rmsnorm is not None and self.ref_proj_weight is not None
+ norm_x = self.ref_rmsnorm(x_in, residual)
+ scores = linear(norm_x.view(-1, self.dim).float(), self.ref_proj_weight.float())
+ return norm_x, scores
+
+ def tilert_forward(self, x_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ assert self.is_tilert_weights_init, "Tilert weights must be initialized before forward pass"
+ assert self.tilert_rms_norm_weight is not None and self.tilert_proj_weight is not None
+ x_in = x_in.to(torch.bfloat16)
+ hidden_out = torch.zeros_like(x_in)
+ scores_out = torch.zeros(
+ (x_in.shape[0], x_in.shape[1], self.n_routed_experts), dtype=torch.float32
+ )
+ torch.ops.tilert.rmsnorm_expert_proj_op(
+ x_in,
+ self.tilert_rms_norm_weight,
+ self.tilert_proj_weight,
+ scores_out,
+ hidden_out,
+ self.model_args.arch_name,
+ "bf16",
+ self.profile_logs,
+ )
+ return hidden_out, scores_out
+
+ def __call__(self, x_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ return self.tilert_forward(x_in)
diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_head_proj.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_head_proj.py
similarity index 84%
rename from python/models/deepseek_v3_2/ops/rmsnorm_head_proj.py
rename to tilert/models/glm_5/_dsa_v32/ops/rmsnorm_head_proj.py
index 6145b5b..fa2086d 100644
--- a/python/models/deepseek_v3_2/ops/rmsnorm_head_proj.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_head_proj.py
@@ -1,74 +1,40 @@
"""RMSNormHeadProj operation module."""
-from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
import torch
from tilert.models.base import TileRTModule, TilertWeightsConverter
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
from tilert.utils import get_profile_log_tensor
__all__ = [
"rmsnorm_head_proj",
- "rmsnorm_head_proj_glm5",
"RMSNormHeadProj",
"RMSNormHeadProjTilertWeightsAlias",
]
def rmsnorm_head_proj(
- hidden_in: torch.Tensor,
- gamma_in: torch.Tensor,
- weight_in: torch.Tensor,
- logits_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """RMS Norm Head Projection operation."""
- torch.ops.tilert.rmsnorm_head_proj_op(
- hidden_in,
- gamma_in,
- weight_in,
- logits_out,
- profile_logs,
- )
-
-
-def rmsnorm_head_proj_dsv32(
hidden_in: torch.Tensor,
gamma_in: torch.Tensor,
weight_in: torch.Tensor,
hidden_rmsnorm_out: torch.Tensor,
logits_out: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> None:
"""RMS Norm Head Projection operation."""
- del hidden_rmsnorm_out
torch.ops.tilert.rmsnorm_head_proj_op(
- hidden_in,
- gamma_in,
- weight_in,
- logits_out,
- profile_logs,
- )
-
-
-def rmsnorm_head_proj_glm5(
- hidden_in: torch.Tensor,
- gamma_in: torch.Tensor,
- weight_in: torch.Tensor,
- hidden_rmsnorm_out: torch.Tensor,
- logits_out: torch.Tensor,
- profile_logs: torch.Tensor,
-) -> None:
- """RMS Norm Head Projection operation."""
- torch.ops.tilert.rmsnorm_head_proj_glm5_op(
hidden_in,
gamma_in,
weight_in,
hidden_rmsnorm_out,
logits_out,
+ model_arch,
+ compute_kernel_type,
profile_logs,
)
@@ -135,6 +101,11 @@ def __call__(self) -> list[str]:
class RMSNormHeadProj(TileRTModule):
"""RMSNormHeadProj module"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormHeadProjAlgorithm.GENERAL],
+ "glm_5": [RMSNormHeadProjAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -155,34 +126,20 @@ def __init__(
self.algorithm = algorithm
self.eps = self.model_args.eps
- # reference weights
self.ref_rmsnorm_gamma: torch.Tensor | None = None
self.ref_head_proj: torch.Tensor | None = None
- # tilert weights
self.tilert_rmsnorm_gamma: torch.Tensor | None = None
self.tilert_head_proj: torch.Tensor | None = None
- # tilert vars
self.hidden_rmsnorm_out: torch.Tensor | None = None
self.hidden_out: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
self.is_init = False
- # tilert_funcs
- self.rmsnorm_head_proj_func: Callable | None = None
-
- if self.arch_name == "deepseek_v3_2":
- self.rmsnorm_head_proj_func = rmsnorm_head_proj_dsv32
- elif self.arch_name == "glm_5":
- self.rmsnorm_head_proj_func = rmsnorm_head_proj_glm5
- else:
- raise ValueError(f"Unsupported architecture: {self.arch_name}")
-
self.tilert_weights_alias = RMSNormHeadProjTilertWeightsAlias()
- # reference tensor aliases
self.ref_tensor_alias: list[str] = [
"model.norm.weight",
"lm_head.weight",
@@ -217,7 +174,6 @@ def device_sharding(
rmsnorm_gamma_key = "model.norm.weight"
head_proj_key = "lm_head.weight"
rmsnorm_gamma = weights_dict[rmsnorm_gamma_key][None, ...]
- # repeat number of devices times
rmsnorm_gamma = rmsnorm_gamma.repeat(self.num_devices, 1)
head_proj = weights_dict[head_proj_key]
@@ -258,7 +214,6 @@ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
batch_size: Batch size.
seq_len: Sequence length.
"""
- # tilert vars
self.hidden_rmsnorm_out = torch.zeros(
(batch_size, seq_len, self.dim),
dtype=torch.bfloat16,
@@ -319,16 +274,16 @@ def tilert_forward(
self,
hidden_in: torch.Tensor,
) -> torch.Tensor:
- assert self.rmsnorm_head_proj_func is not None
assert self.hidden_out is not None
- self.rmsnorm_head_proj_func(
+ rmsnorm_head_proj(
hidden_in,
self.tilert_rmsnorm_gamma,
self.tilert_head_proj,
self.hidden_rmsnorm_out,
self.hidden_out,
self.profile_logs,
+ model_arch=self.model_args.arch_name,
)
return self.hidden_out
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_kv.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_kv.py
new file mode 100644
index 0000000..81d161c
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_kv.py
@@ -0,0 +1,204 @@
+"""RMSNormKV operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "rmsnorm_kv",
+ "KVRMSNorm",
+ "KVRMSNormRefWeightsAlias",
+ "KVRMSNormTilertWeightsAlias",
+]
+
+
+def rmsnorm_kv(
+ kv: torch.Tensor,
+ gamma: torch.Tensor,
+ cur_pos: torch.Tensor,
+ kv_cache: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
+) -> None:
+ """
+ Define the RMSNormKV operation.
+
+ Args:
+ kv: Input tensor.
+ gamma: Weight tensor.
+ cur_pos: Current position tensor.
+ kv_cache: Output tensor.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture string.
+ compute_kernel_type: Compute kernel type string.
+ """
+ torch.ops.tilert.rmsnorm_kv_op(
+ kv, gamma, cur_pos, kv_cache, model_arch, compute_kernel_type, profile_logs
+ )
+
+
+@dataclass
+class KVRMSNormRefWeightsAlias:
+ """Reference weights alias for KVRMSNorm."""
+
+ kv_norm_weight = "self_attn.kv_a_layernorm.weight"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.kv_norm_weight]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class KVRMSNormTilertWeightsAlias:
+ """TileRT weights alias for KVRMSNorm."""
+
+ kv_norm_gamma = "kv_rmsnorm_gamma"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.kv_norm_gamma]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class KVRMSNormAlgorithm(Enum):
+ """KVRMSNorm algorithm."""
+
+ GENERAL = "general"
+
+
+class KVRMSNorm(TileRTModule):
+ """KVRMSNorm module: RMSNorm on KV tensor with in-place write to kv_cache."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [KVRMSNormAlgorithm.GENERAL],
+ "glm_5": [KVRMSNormAlgorithm.GENERAL],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: KVRMSNormRefWeightsAlias | None = None,
+ tilert_weights_alias: KVRMSNormTilertWeightsAlias | None = None,
+ layer_idx: int = 0,
+ golden_weights_dir: str = "",
+ tilert_weights_dir: str = "",
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ layer_idx=layer_idx,
+ golden_weights_dir=golden_weights_dir,
+ tilert_weights_dir=tilert_weights_dir,
+ )
+
+ self.tilert_weights_alias = (
+ tilert_weights_alias
+ if tilert_weights_alias is not None
+ else KVRMSNormTilertWeightsAlias()
+ )
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else KVRMSNormRefWeightsAlias()
+ )
+
+ self.kv_lora_rank = self.model_args.kv_lora_rank
+ self.eps = self.model_args.eps
+
+ self.ref_norm_gamma: torch.Tensor | None = None
+ self.tilert_kv_norm_weight: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_kv_norm_weight]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Device sharding: replicate gamma for each device.
+
+ Args:
+ weights_map: Map from ref weight alias to tensor.
+
+ Returns:
+ Map from tilert weight alias to (num_devices, ...) tensors.
+ """
+ gamma = weights_map[self.ref_weights_alias.kv_norm_weight][None, ...].repeat(
+ self.num_devices, 1
+ )
+ return {self.tilert_weights_alias.kv_norm_gamma: gamma}
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize reference weights from state dict."""
+ self.ref_norm_gamma = state_dict[self.ref_weights_alias.kv_norm_weight].contiguous()
+ assert (
+ self.ref_norm_gamma.shape[-1] == self.kv_lora_rank
+ ), f"kv_norm weight shape must be ({self.kv_lora_rank},), got {self.ref_norm_gamma.shape}"
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize TileRT weights from state dict."""
+ gamma = state_dict[self.tilert_weights_alias.kv_norm_gamma]
+ self.tilert_kv_norm_weight = gamma.float().detach().clone().contiguous()
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate TileRT profiling buffer."""
+ del batch_size, seq_len
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def init_random_weights(self) -> None:
+ """Initialize random reference and TileRT weights for testing."""
+ ref_state_dict = {
+ self.ref_weights_alias.kv_norm_weight: torch.randn(
+ self.kv_lora_rank, dtype=torch.float32
+ ),
+ }
+ self.init_reference_weights(ref_state_dict)
+ sharded = self.device_sharding(ref_state_dict)
+ self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()})
+
+ def golden_forward(
+ self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int
+ ) -> None:
+ """Reference forward: RMSNorm and write to kv_cache."""
+ assert self.ref_norm_gamma is not None
+ end_pos = start_pos + seqlen
+ out = torch.nn.functional.rms_norm(
+ kv.float(), [kv.size(-1)], self.ref_norm_gamma, self.eps
+ ).to(kv.dtype)
+ kv_cache[:bsz, start_pos:end_pos].copy_(out)
+
+ def tilert_forward(
+ self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int
+ ) -> None:
+ del seqlen
+ assert self.tilert_kv_norm_weight is not None
+ assert self.profile_logs is not None
+ cur_pos = torch.tensor([start_pos], dtype=torch.int32, device=kv.device)
+ rmsnorm_kv(
+ kv,
+ self.tilert_kv_norm_weight,
+ cur_pos,
+ kv_cache[:bsz],
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+
+ def __call__(
+ self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int
+ ) -> None:
+ if self.flag_enable_tilert:
+ return self.tilert_forward(kv, kv_cache, start_pos, bsz, seqlen)
+ return self.golden_forward(kv, kv_cache, start_pos, bsz, seqlen)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projq_wqb.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projq_wqb.py
new file mode 100644
index 0000000..92d7a99
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projq_wqb.py
@@ -0,0 +1,530 @@
+"""RmsnormProjqWqb operation module."""
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RmsnormProjqWqb",
+ "RmsnormProjqWqbAlgorithm",
+ "RmsnormProjqWqbWeightsConverter",
+]
+
+
+def rmsnorm_projq_wqb_op(
+ q: torch.Tensor,
+ wq_b: torch.Tensor,
+ wq_b_scales: torch.Tensor,
+ q_norm_weight: torch.Tensor,
+ q_nope: torch.Tensor,
+ q_pe: torch.Tensor,
+ profile_logs: torch.Tensor,
+ algorithm: str,
+ model_arch: str,
+) -> None:
+ torch.ops.tilert.rmsnorm_proj_qb_op(
+ q,
+ wq_b,
+ wq_b_scales,
+ q_norm_weight,
+ q_nope,
+ q_pe,
+ model_arch,
+ algorithm,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=q.device),
+ )
+
+
+class RmsnormProjqWqbAlgorithm(Enum):
+ """RmsnormProjqWqb algorithm."""
+
+ FP16MMA = "fp16mma"
+
+
+class RmsnormProjqWqbWeightsConverter(TilertWeightsConverter):
+ """Weights converter for RmsnormProjqWqb.
+
+ Supports configurations where n_heads is not evenly divisible by
+ num_devices; in that case n_local_heads is padded and padded head
+ weight rows are zero-filled.
+ """
+
+ kBf16NumCtas = 80
+ kGemvPageSize = 8
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args=model_args, num_devices=num_devices)
+
+ self.proc_groups = 8
+ self.repeat = 16
+
+ self.block_size = self.model_args.block_size
+
+ self.qk_nope_head_dim = self.model_args.qk_nope_head_dim
+ self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+
+ self.needs_padding = self.model_args.n_heads % num_devices != 0
+ self.n_local_heads = self._compute_n_local_heads(
+ self.model_args.n_heads, num_devices, self.qk_head_dim
+ )
+
+ self.q_lora_dim = self.model_args.q_lora_rank
+ self.q_lora_qdim = self.q_lora_dim // self.block_size
+
+ self.qk_dim = self.qk_head_dim * self.n_local_heads
+ self.qk_qdim = self.qk_dim // self.block_size
+
+ assert self.qk_dim % (self.kBf16NumCtas * self.kGemvPageSize) == 0, (
+ f"qk_dim ({self.qk_dim}) must be divisible by "
+ f"kBf16NumCtas * kGemvPageSize ({self.kBf16NumCtas * self.kGemvPageSize})"
+ )
+ assert self.qk_dim % self.block_size == 0, (
+ f"qk_dim ({self.qk_dim}) must be divisible by block_size ({self.block_size}) "
+ f"for scale alignment"
+ )
+
+ @classmethod
+ def _compute_n_local_heads(cls, n_total_heads: int, num_devices: int, qk_head_dim: int) -> int:
+ """Compute padded n_local_heads per device."""
+ if n_total_heads % num_devices == 0:
+ return n_total_heads // num_devices
+
+ base = math.ceil(n_total_heads / num_devices)
+ align_unit = cls.kBf16NumCtas * cls.kGemvPageSize
+ g = math.gcd(qk_head_dim, align_unit)
+ head_align = align_unit // g
+ return math.ceil(base / head_align) * head_align
+
+ @staticmethod
+ def _redistribute_heads(
+ wq_b_full: torch.Tensor,
+ wq_b_scale_full: torch.Tensor,
+ n_total_heads: int,
+ n_local_heads: int,
+ num_devices: int,
+ qk_head_dim: int,
+ block_size: int,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """Redistribute heads across devices with padding.
+
+ Args:
+ wq_b_full: [n_total_heads * qk_head_dim, q_lora_dim] full weight.
+ wq_b_scale_full: [n_total_heads * qk_head_dim // block_size, q_lora_qdim] full scale.
+ n_total_heads: Total number of heads (e.g. 128).
+ n_local_heads: Target heads per GPU (padded, e.g. 20).
+ num_devices: Number of devices (e.g. 7).
+ qk_head_dim: Head dimension (e.g. 192).
+ block_size: Quantization block size (e.g. 128).
+
+ Returns:
+ Lists of per-device (wq_b, wq_b_scale) with shape
+ [n_local_heads * qk_head_dim, q_lora_dim] and
+ [n_local_heads * qk_head_dim // block_size, q_lora_qdim].
+ """
+ total_rows = n_total_heads * qk_head_dim
+ rows_per_dev = n_local_heads * qk_head_dim
+ scale_rows_per_dev = rows_per_dev // block_size
+ total_scale_rows = total_rows // block_size
+
+ q_lora_dim = wq_b_full.shape[-1]
+ q_lora_qdim = wq_b_scale_full.shape[-1]
+
+ assert rows_per_dev % block_size == 0, (
+ f"n_local_heads * qk_head_dim ({rows_per_dev}) must be "
+ f"divisible by block_size ({block_size})"
+ )
+
+ wq_b_list = []
+ scale_list = []
+ for dev in range(num_devices):
+ start_row = dev * rows_per_dev
+ end_row = min(total_rows, start_row + rows_per_dev)
+ real_rows = max(0, end_row - start_row)
+
+ dev_wqb = torch.zeros(
+ rows_per_dev, q_lora_dim, dtype=wq_b_full.dtype, device=wq_b_full.device
+ )
+ if real_rows > 0:
+ dev_wqb[:real_rows] = wq_b_full[start_row:end_row]
+
+ start_scale = dev * scale_rows_per_dev
+ end_scale = min(total_scale_rows, start_scale + scale_rows_per_dev)
+ real_scale_rows = max(0, end_scale - start_scale)
+
+ dev_scale = torch.zeros(
+ scale_rows_per_dev,
+ q_lora_qdim,
+ dtype=wq_b_scale_full.dtype,
+ device=wq_b_scale_full.device,
+ )
+ if real_scale_rows > 0:
+ dev_scale[:real_scale_rows] = wq_b_scale_full[start_scale:end_scale]
+
+ wq_b_list.append(dev_wqb)
+ scale_list.append(dev_scale)
+
+ return wq_b_list, scale_list
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(
+ mat_in: torch.Tensor, q_lora_dim: int, pages: int
+ ) -> torch.Tensor:
+ """Swizzle a 16xK matrix for the paged weight layout (K divisible by 16)."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == q_lora_dim
+ k_per_page = q_lora_dim // pages
+ n_k_tiles = k_per_page // 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = RmsnormProjqWqbWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def _common_to_tilert_fp16mma(
+ self,
+ wq_b: torch.Tensor,
+ wq_b_scales_raw: torch.Tensor,
+ rmsnorm_gamma: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common weights to the packed TileRT FP16 layout."""
+ pages = 2
+ rows_per_cta = 32
+
+ qk_nope_dim = self.n_local_heads * self.qk_nope_head_dim
+ qk_pe_dim = self.n_local_heads * self.qk_rope_head_dim
+ nope_ctas = qk_nope_dim // rows_per_cta
+ pe_ctas = qk_pe_dim // rows_per_cta
+ num_ctas = nope_ctas + pe_ctas
+
+ wq_b_scales_f32 = wq_b_scales_raw.to(torch.float32)
+ wq_b_scales_f32 = (
+ wq_b_scales_f32.reshape(self.qk_qdim, 1, self.q_lora_qdim)
+ .repeat(1, self.block_size, 1)
+ .reshape(self.qk_dim, self.q_lora_qdim)
+ )
+
+ wq_b_scales_f32 = wq_b_scales_f32.reshape(
+ self.n_local_heads, self.qk_head_dim, self.q_lora_qdim
+ )
+ scale_nope = wq_b_scales_f32[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_qdim)
+ scale_pe = wq_b_scales_f32[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_qdim)
+
+ scale_nope = scale_nope.reshape(
+ nope_ctas, rows_per_cta, pages, self.q_lora_qdim // pages
+ ).transpose(1, 2)[:, :, 0, :]
+ scale_pe = scale_pe.reshape(
+ pe_ctas, rows_per_cta, pages, self.q_lora_qdim // pages
+ ).transpose(1, 2)[:, :, 0, :]
+
+ scales = torch.cat([scale_nope, scale_pe], dim=0)
+ scales_fp8 = scales.contiguous().view(torch.float8_e4m3fn)
+
+ wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim)
+ wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_dim)
+ wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_dim)
+
+ wq_b_nope = wq_b_nope.reshape(nope_ctas, rows_per_cta // 16, 16, self.q_lora_dim)
+ wq_b_nope = RmsnormProjqWqbWeightsConverter._swizzle_mma_16x16_for_pages(
+ wq_b_nope, self.q_lora_dim, pages
+ )
+ wq_b_nope = (
+ wq_b_nope.reshape(nope_ctas, rows_per_cta // 16, pages, 16, -1)
+ .transpose(1, 2)
+ .reshape(nope_ctas, pages, rows_per_cta, -1)
+ )
+
+ wq_b_pe = wq_b_pe.reshape(pe_ctas, rows_per_cta // 16, 16, self.q_lora_dim)
+ wq_b_pe = RmsnormProjqWqbWeightsConverter._swizzle_mma_16x16_for_pages(
+ wq_b_pe, self.q_lora_dim, pages
+ )
+ wq_b_pe = (
+ wq_b_pe.reshape(pe_ctas, rows_per_cta // 16, pages, 16, -1)
+ .transpose(1, 2)
+ .reshape(pe_ctas, pages, rows_per_cta, -1)
+ )
+
+ weights = torch.cat([wq_b_nope, wq_b_pe], dim=0)
+ weights = weights.reshape(num_ctas, pages, -1)
+
+ scale_padding_size = 128 - scales_fp8.shape[-1]
+ scale_padding = torch.zeros(
+ num_ctas,
+ pages,
+ scale_padding_size,
+ dtype=torch.float8_e4m3fn,
+ device=wq_b.device,
+ )
+ tilert_wqb = torch.cat([weights, scales_fp8, scale_padding], dim=-1).contiguous()
+
+ tilert_wqb_scales = torch.zeros(1, dtype=torch.bfloat16)
+ tilert_gamma = rmsnorm_gamma.float().detach().clone()
+ return tilert_wqb, tilert_wqb_scales, tilert_gamma
+
+ def convert_to_fp16mma(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common-format weights to TileRT FP16 MMA layout."""
+ with torch.inference_mode():
+ wq_b, wq_b_scale, q_norm_weight = weights
+ return self._common_to_tilert_fp16mma(wq_b, wq_b_scale, q_norm_weight)
+
+
+@dataclass
+class RmsnormProjqWqbRefWeightsAlias:
+ """Reference weights alias for RmsnormProjqWqb."""
+
+ rmsnorm_gamma = "self_attn.q_a_layernorm.weight"
+ wqb_weights = "self_attn.q_b_proj.weight"
+ wqb_scales = "self_attn.q_b_proj.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.rmsnorm_gamma,
+ self.wqb_weights,
+ self.wqb_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class RmsnormProjqWqbTilertWeightsAlias:
+ """TileRT weights alias for RmsnormProjqWqb."""
+
+ rmsnorm_gamma = "q_rmsnorm_gamma"
+ wqb_weights = "wqb_weights"
+ wqb_scales = "wqb_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.rmsnorm_gamma,
+ self.wqb_weights,
+ self.wqb_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RmsnormProjqWqb(TileRTModule):
+ """RmsnormProjqWqb module: RMSNorm + Q projection (wq_b only)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RmsnormProjqWqbAlgorithm.FP16MMA],
+ "glm_5": [RmsnormProjqWqbAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int = 7,
+ ref_weights_alias: RmsnormProjqWqbRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+
+ self.tilert_weights_alias = RmsnormProjqWqbTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias if ref_weights_alias is not None else RmsnormProjqWqbRefWeightsAlias()
+ )
+
+ self.n_local_heads = RmsnormProjqWqbWeightsConverter._compute_n_local_heads(
+ model_args.n_heads,
+ num_devices,
+ model_args.qk_nope_head_dim + model_args.qk_rope_head_dim,
+ )
+ self.q_lora_rank = model_args.q_lora_rank
+ self.n_heads = model_args.n_heads
+ self.qk_nope_head_dim = model_args.qk_nope_head_dim
+ self.qk_rope_head_dim = model_args.qk_rope_head_dim
+ self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+ self.qk_local_dim = self.qk_head_dim * self.n_local_heads
+
+ self.block_size = model_args.block_size
+ self.q_lora_qdim = self.q_lora_rank // self.block_size
+ self.qk_local_qdim = self.qk_local_dim // self.block_size
+ self.eps = model_args.eps
+
+ self.ref_q_norm: torch.Tensor | None = None
+ self.ref_wq_b: torch.Tensor | None = None
+
+ self.tilert_wq_b: torch.Tensor | None = None
+ self.tilert_wq_b_scales: torch.Tensor | None = None
+ self.tilert_q_norm_weight: torch.Tensor | None = None
+
+ self.q_nope: torch.Tensor | None = None
+ self.q_pe: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_q_norm_weight, self.tilert_wq_b, self.tilert_wq_b_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Redistribute heads across devices with padding."""
+ gamma = weights_map[self.ref_weights_alias.rmsnorm_gamma][None, ...].repeat(
+ self.num_devices, 1
+ )
+
+ wq_b_full = weights_map[self.ref_weights_alias.wqb_weights]
+ wq_b_scale_full = weights_map[self.ref_weights_alias.wqb_scales]
+
+ wq_b_list, scale_list = RmsnormProjqWqbWeightsConverter._redistribute_heads(
+ wq_b_full,
+ wq_b_scale_full,
+ n_total_heads=self.n_heads,
+ n_local_heads=self.n_local_heads,
+ num_devices=self.num_devices,
+ qk_head_dim=self.qk_head_dim,
+ block_size=self.block_size,
+ )
+
+ sharded_wqb_weights = torch.stack(wq_b_list, dim=0)
+ sharded_wqb_scales = torch.stack(scale_list, dim=0)
+
+ return {
+ self.tilert_weights_alias.rmsnorm_gamma: gamma,
+ self.tilert_weights_alias.wqb_weights: sharded_wqb_weights,
+ self.tilert_weights_alias.wqb_scales: sharded_wqb_scales,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize reference weights from common-format state dict."""
+ self.ref_q_norm = state_dict[self.ref_weights_alias.rmsnorm_gamma]
+
+ wq_b_full = state_dict[self.ref_weights_alias.wqb_weights]
+ wq_b_scale_full = state_dict[self.ref_weights_alias.wqb_scales]
+
+ wq_b_bf16_full = weight_dequant(wq_b_full, wq_b_scale_full)
+
+ total_rows = self.n_heads * self.qk_head_dim
+ rows_per_dev = self.n_local_heads * self.qk_head_dim
+ start_row = self.device_id * rows_per_dev
+ end_row = min(total_rows, start_row + rows_per_dev)
+ real_rows = max(0, end_row - start_row)
+
+ dev_wqb = torch.zeros(
+ rows_per_dev,
+ wq_b_bf16_full.shape[-1],
+ dtype=wq_b_bf16_full.dtype,
+ device=wq_b_bf16_full.device,
+ )
+ if real_rows > 0:
+ dev_wqb[:real_rows] = wq_b_bf16_full[start_row:end_row]
+
+ self.ref_wq_b = dev_wqb.contiguous()
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize TileRT weights from common-format state dict."""
+ weights = [
+ state_dict[self.tilert_weights_alias.wqb_weights],
+ state_dict[self.tilert_weights_alias.wqb_scales],
+ state_dict[self.tilert_weights_alias.rmsnorm_gamma],
+ ]
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_wq_b, self.tilert_wq_b_scales, self.tilert_q_norm_weight = (
+ RmsnormProjqWqbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ self.algorithm, weights
+ )
+ )
+
+ def init_random_weights(self) -> None:
+ """Initialize random reference and TileRT weights for testing."""
+ q_norm = torch.randn(self.q_lora_rank, dtype=torch.float32)
+
+ wq_b = torch.randn(self.qk_local_dim, self.q_lora_rank, dtype=torch.bfloat16).to(
+ torch.float8_e4m3fn
+ )
+ scale_dtype = torch.float32 if self.model_args.arch_name == "glm_5" else torch.bfloat16
+ wq_b_scale = torch.randn(self.qk_local_qdim, self.q_lora_qdim, dtype=scale_dtype)
+
+ self.ref_q_norm = q_norm
+ self.ref_wq_b = weight_dequant(wq_b, wq_b_scale).contiguous()
+
+ assert self.algorithm is not None, "Algorithm is not set"
+ weights = [wq_b, wq_b_scale, q_norm]
+ self.tilert_wq_b, self.tilert_wq_b_scales, self.tilert_q_norm_weight = (
+ RmsnormProjqWqbWeightsConverter(self.model_args, self.num_devices).dispatch(
+ self.algorithm, weights
+ )
+ )
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate TileRT output buffers."""
+ self.q_nope = torch.zeros(
+ batch_size, seq_len, self.n_local_heads, self.qk_nope_head_dim, dtype=torch.bfloat16
+ )
+ self.q_pe = torch.zeros(
+ batch_size, seq_len, self.n_local_heads, self.qk_rope_head_dim, dtype=torch.bfloat16
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Reference forward: RMSNorm + linear projection (no iq)."""
+ assert self.ref_q_norm is not None
+ assert self.ref_wq_b is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ qr = torch.nn.functional.rms_norm(q.float(), [q.size(-1)], self.ref_q_norm, self.eps).to(
+ q.dtype
+ )
+
+ q_out = torch.matmul(qr, self.ref_wq_b.T)
+ q_out = q_out.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
+ q_nope, q_pe = torch.split(q_out, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ return q_nope, q_pe
+
+ def tilert_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ assert self.tilert_wq_b is not None
+ assert self.tilert_wq_b_scales is not None
+ assert self.tilert_q_norm_weight is not None
+ assert self.q_nope is not None
+ assert self.q_pe is not None
+ assert self.profile_logs is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ assert self.algorithm is not None, "Algorithm is not set"
+
+ rmsnorm_projq_wqb_op(
+ q,
+ self.tilert_wq_b,
+ self.tilert_wq_b_scales,
+ self.tilert_q_norm_weight,
+ self.q_nope,
+ self.q_pe,
+ self.profile_logs,
+ self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return self.q_nope, self.q_pe
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projq_wqi.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projq_wqi.py
new file mode 100644
index 0000000..4f4d07f
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projq_wqi.py
@@ -0,0 +1,330 @@
+"""RmsnormProjqWqi operation module (IQ-only projection)."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+from einops import rearrange
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RmsnormProjqWqi",
+ "RmsnormProjqWqiAlgorithm",
+ "RmsnormProjqWqiWeightsConverter",
+]
+
+
+def rmsnorm_projq_wqi_op(
+ q: torch.Tensor,
+ wqi: torch.Tensor,
+ wqi_scale: torch.Tensor,
+ rmsnorm_gamma: torch.Tensor,
+ iq: torch.Tensor,
+ profile_logs: torch.Tensor,
+ algorithm: str,
+ model_arch: str,
+) -> None:
+ torch.ops.tilert.rmsnorm_proj_qi_op(
+ q,
+ wqi,
+ wqi_scale,
+ rmsnorm_gamma,
+ iq,
+ model_arch,
+ algorithm,
+ profile_logs,
+ )
+
+
+class RmsnormProjqWqiAlgorithm(Enum):
+ """RmsnormProjqWqi algorithm."""
+
+ FP16MMA = "fp16mma"
+
+
+class RmsnormProjqWqiWeightsConverter(TilertWeightsConverter):
+ """Weights converter: common format to TileRT format (IQ only)."""
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args=model_args, num_devices=num_devices)
+
+ self.block_size = self.model_args.block_size
+ self.q_lora_dim = self.model_args.q_lora_rank
+ self.q_lora_qdim = self.q_lora_dim // self.block_size
+
+ self.index_n_heads = self.model_args.index_n_heads
+ self.index_head_dim = self.index_n_heads * self.model_args.index_head_dim
+ self.index_head_qdim = self.index_head_dim // self.block_size
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def _swizzle_mma_16x16_for_pages(
+ mat_in: torch.Tensor, q_lora_rank: int, pages: int
+ ) -> torch.Tensor:
+ """Swizzle a 16xK matrix for the paged weight layout (K divisible by 16)."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == q_lora_rank
+ pre_shape = mat_in.shape[:-2]
+ k_per_page = q_lora_rank // pages
+ n_k_tiles = k_per_page // 16
+ mat_in = mat_in.reshape(*pre_shape, 16, pages, k_per_page).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, pages, 16, n_k_tiles, 16).transpose(-3, -2)
+ mat_in = RmsnormProjqWqiWeightsConverter._swizzle_mma_16x16(mat_in)
+ return mat_in.contiguous()
+
+ def _common_to_tilert_fp16mma(
+ self,
+ wqi: torch.Tensor,
+ wqi_scales: torch.Tensor,
+ rmsnorm_gamma: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common weights to the packed TileRT FP16 layout (IQ only)."""
+ sms = 128
+ k_per_page = 1024 if self.model_args.arch_name == "glm_5" else 512
+ pages = self.q_lora_dim // k_per_page
+ iq_dim_per_sm = self.index_head_dim // sms
+
+ wqi_scales_f32 = wqi_scales.to(torch.float32)
+ wqi_scales_f32 = (
+ wqi_scales_f32.reshape(self.index_head_qdim, 1, self.q_lora_qdim)
+ .repeat(1, self.block_size, 1)
+ .reshape(self.index_head_dim, self.q_lora_qdim)
+ )
+ wqi_scales_f32 = wqi_scales_f32.reshape(
+ sms, iq_dim_per_sm, pages, self.q_lora_qdim // pages
+ ).transpose(1, 2)
+ wqi_scales_f32 = wqi_scales_f32[:, :, 0, :]
+ wqi_full_scales = wqi_scales_f32.contiguous().view(torch.float8_e4m3fn)
+
+ wqi_mat = wqi.reshape(sms, iq_dim_per_sm // 16, 16, self.q_lora_dim)
+ wqi_mat = RmsnormProjqWqiWeightsConverter._swizzle_mma_16x16_for_pages(
+ wqi_mat, self.q_lora_dim, pages
+ )
+ wqi_mat = (
+ wqi_mat.reshape(sms, iq_dim_per_sm // 16, pages, 16, -1)
+ .transpose(1, 2)
+ .reshape(sms, pages, iq_dim_per_sm, -1)
+ )
+ wqi_mat = wqi_mat.reshape(sms, pages, -1)
+
+ wqi_scales_padding = torch.zeros(
+ sms,
+ pages,
+ 128 - wqi_full_scales.shape[-1],
+ dtype=torch.float8_e4m3fn,
+ device=wqi.device,
+ )
+ tilert_wqi = torch.cat([wqi_mat, wqi_full_scales, wqi_scales_padding], dim=-1).contiguous()
+ tilert_wqi_scales = torch.zeros(1, dtype=torch.bfloat16)
+ tilert_gamma = rmsnorm_gamma.float().detach().clone()
+ return tilert_wqi, tilert_wqi_scales, tilert_gamma
+
+ def convert_to_fp16mma(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert common-format weights to TileRT FP16 MMA layout.
+
+ Args:
+ weights: [wqi, wqi_scale, q_norm_weight].
+ """
+ with torch.inference_mode():
+ wqi, wqi_scale, q_norm_weight = weights
+ return self._common_to_tilert_fp16mma(wqi, wqi_scale, q_norm_weight)
+
+
+@dataclass
+class RmsnormProjqWqiRefWeightsAlias:
+ """Reference (HuggingFace) weights alias for RmsnormProjqWqi."""
+
+ rmsnorm_gamma = "self_attn.q_a_layernorm.weight"
+ wqi_weights = "self_attn.indexer.wq_b.weight"
+ wqi_scales = "self_attn.indexer.wq_b.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [self.rmsnorm_gamma, self.wqi_weights, self.wqi_scales]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+@dataclass
+class RmsnormProjqWqiTilertWeightsAlias:
+ """TileRT weights alias for RmsnormProjqWqi."""
+
+ rmsnorm_gamma = "q_rmsnorm_gamma_qi"
+ wqi_weights = "wqi_weights"
+ wqi_scales = "wqi_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [self.rmsnorm_gamma, self.wqi_weights, self.wqi_scales]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RmsnormProjqWqi(TileRTModule):
+ """RmsnormProjqWqi module: RMSNorm + W_qi projection (IQ only, GLM5 v2)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RmsnormProjqWqiAlgorithm.FP16MMA],
+ "glm_5": [RmsnormProjqWqiAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+
+ self.tilert_weights_alias = RmsnormProjqWqiTilertWeightsAlias()
+ self.ref_weights_alias = RmsnormProjqWqiRefWeightsAlias()
+
+ self.q_lora_rank = model_args.q_lora_rank
+ self.index_n_heads = model_args.index_n_heads
+ self.head_dim = model_args.index_head_dim
+ self.index_head_dim = model_args.index_n_heads * model_args.index_head_dim
+
+ self.block_size = model_args.block_size
+ self.q_lora_qdim = self.q_lora_rank // self.block_size
+ self.index_head_qdim = self.index_head_dim // self.block_size
+ self.eps = model_args.eps
+
+ self.ref_q_norm: torch.Tensor | None = None
+ self.ref_wqi: torch.Tensor | None = None
+
+ self.tilert_wqi: torch.Tensor | None = None
+ self.tilert_wqi_scales: torch.Tensor | None = None
+ self.tilert_q_norm_weight: torch.Tensor | None = None
+
+ self.iq: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_q_norm_weight, self.tilert_wqi, self.tilert_wqi_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Replicate IQ weights across devices (no per-head redistribution needed)."""
+ gamma = (
+ weights_map[self.ref_weights_alias.rmsnorm_gamma][None, ...]
+ .float()
+ .repeat(self.num_devices, 1)
+ )
+ wqi_weights = weights_map[self.ref_weights_alias.wqi_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wqi_scales = weights_map[self.ref_weights_alias.wqi_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ return {
+ self.tilert_weights_alias.rmsnorm_gamma: gamma,
+ self.tilert_weights_alias.wqi_weights: wqi_weights,
+ self.tilert_weights_alias.wqi_scales: wqi_scales,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize reference weights from common-format state dict."""
+ self.ref_q_norm = state_dict[self.tilert_weights_alias.rmsnorm_gamma]
+ wqi = weight_dequant(
+ state_dict[self.tilert_weights_alias.wqi_weights],
+ state_dict[self.tilert_weights_alias.wqi_scales],
+ )
+ self.ref_wqi = wqi.contiguous()
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """Initialize TileRT weights from common-format state dict."""
+ weights = [
+ state_dict[self.tilert_weights_alias.wqi_weights],
+ state_dict[self.tilert_weights_alias.wqi_scales],
+ state_dict[self.tilert_weights_alias.rmsnorm_gamma],
+ ]
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_wqi, self.tilert_wqi_scales, self.tilert_q_norm_weight = (
+ RmsnormProjqWqiWeightsConverter(self.model_args, self.num_devices).dispatch(
+ self.algorithm, weights
+ )
+ )
+
+ def init_random_weights(self) -> None:
+ """Initialize random reference and TileRT weights for testing."""
+ q_norm = torch.randn(self.q_lora_rank, dtype=torch.float32)
+ wqi = torch.randn(self.index_head_dim, self.q_lora_rank, dtype=torch.bfloat16).to(
+ torch.float8_e4m3fn
+ )
+ scale_dtype = torch.float32 if self.model_args.arch_name == "glm_5" else torch.bfloat16
+ wqi_scale = torch.randn(self.index_head_qdim, self.q_lora_qdim, dtype=scale_dtype)
+
+ ref_state = {
+ self.tilert_weights_alias.rmsnorm_gamma: q_norm,
+ self.tilert_weights_alias.wqi_weights: wqi,
+ self.tilert_weights_alias.wqi_scales: wqi_scale,
+ }
+
+ self.init_reference_weights(ref_state)
+ self.init_tilert_weights(ref_state)
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ """Allocate TileRT output buffers."""
+ self.iq = torch.zeros(
+ batch_size, seq_len, self.index_n_heads, self.head_dim, dtype=torch.bfloat16
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_var_init = True
+
+ def golden_forward(self, q: torch.Tensor) -> torch.Tensor:
+ """Reference forward: RMSNorm + W_qi_b linear projection."""
+ assert self.ref_q_norm is not None
+ assert self.ref_wqi is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4, 8]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ qr = torch.nn.functional.rms_norm(q.float(), [q.size(-1)], self.ref_q_norm, self.eps).to(
+ q.dtype
+ )
+
+ return rearrange(torch.matmul(qr, self.ref_wqi.T), "b s (h d) -> b s h d", d=self.head_dim)
+
+ def tilert_forward(self, q: torch.Tensor) -> torch.Tensor:
+ assert self.tilert_wqi is not None
+ assert self.tilert_wqi_scales is not None
+ assert self.tilert_q_norm_weight is not None
+ assert self.iq is not None
+ assert self.profile_logs is not None
+
+ bsz, seqlen, _ = q.shape
+ if bsz != 1 or seqlen not in [1, 2, 4, 8]:
+ raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}")
+
+ assert self.algorithm is not None, "Algorithm is not set"
+
+ rmsnorm_projq_wqi_op(
+ q,
+ self.tilert_wqi,
+ self.tilert_wqi_scales,
+ self.tilert_q_norm_weight,
+ self.iq,
+ self.profile_logs,
+ self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return self.iq
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projx_wqakis.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projx_wqakis.py
new file mode 100644
index 0000000..8813d6a
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projx_wqakis.py
@@ -0,0 +1,341 @@
+"""RMSNormProjxWqakis operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.ops.projx_wis import projx_wis
+from tilert.models.glm_5._dsa_v32.ops.projx_wqaki import (
+ ProjxWqakiWeightsConverter,
+ projx_wqaki,
+)
+from tilert.models.glm_5._dsa_v32.ops.rmsnorm_quant import rmsnorm_quant
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RMSNormProjxWqakis",
+]
+
+
+class RMSNormProjxWqakisWeightsConverter(TilertWeightsConverter):
+ """Weight converter for RMSNormProjxWqakis."""
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args, num_devices)
+
+ def convert_to_decoupled(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert weights to decoupled FP8 MMA format.
+
+ Args:
+ weights: [gamma, wq_a, wq_a_scale, wki, wki_scale, wis, wis_scale]
+
+ Returns:
+ (wqaki_packed, wis_bf16, gamma)
+ """
+ arch_name = self.model_args.arch_name
+ x_rmsnorm_gamma, wq_a, wq_a_scale, wki, wki_scale, wis, _wis_scale = weights
+
+ if arch_name == "deepseek_v3_2":
+ wqaki_packed = ProjxWqakiWeightsConverter.convert_dsv32(
+ wq_a, wq_a_scale, wki, wki_scale
+ )
+ elif arch_name == "glm_5":
+ wqaki_packed = ProjxWqakiWeightsConverter.convert_glm5_68cta(
+ wq_a, wq_a_scale, wki, wki_scale
+ )
+ else:
+ raise ValueError(f"Unsupported architecture: {arch_name}")
+
+ wis_bf16 = wis.to(torch.bfloat16)
+ return wqaki_packed, wis_bf16, x_rmsnorm_gamma.float()
+
+
+class RMSNormProjxWqakisRefWeightsAlias:
+ """Reference weight aliases for RMSNormProjxWqakis."""
+
+ x_rmsnorm_gamma = "input_layernorm.weight"
+ q_a_weights = "self_attn.q_a_proj.weight"
+ q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
+ wk_weights = "self_attn.indexer.wk.weight"
+ wk_scales = "self_attn.indexer.wk.weight_scale_inv"
+ wis_weights = "self_attn.indexer.weights_proj.weight"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.wk_weights,
+ self.wk_scales,
+ self.wis_weights,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+class RMSNormProjxWqakisTilertWeightsAlias:
+ """Tilert weight aliases for RMSNormProjxWqakis."""
+
+ x_rmsnorm_gamma = "x_rmsnorm_gamma"
+ q_a_weights = "q_a_weights"
+ q_a_scales = "q_a_scales"
+ wk_weights = "wk_weights"
+ wk_scales = "wk_scales"
+ wis_weights = "wis_weights"
+ wis_scales = "wis_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.wk_weights,
+ self.wk_scales,
+ self.wis_weights,
+ self.wis_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RMSNormProjxWqakisAlgorithm(Enum):
+ """RMSNormProjxWqakis algorithm."""
+
+ FP8MMA = "fp8mma"
+
+
+class RMSNormProjxWqakis(TileRTModule):
+ """Decoupled RMSNorm + GEMV(W_q_a, W_ki, W_is)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormProjxWqakisAlgorithm.FP8MMA],
+ "glm_5": [RMSNormProjxWqakisAlgorithm.FP8MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: RMSNormProjxWqakisRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = RMSNormProjxWqakisTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else RMSNormProjxWqakisRefWeightsAlias()
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+ self.q_lora_rank = self.model_args.q_lora_rank
+ self.idx_head_dim = self.model_args.index_head_dim
+ self.idx_score_dim = self.model_args.index_n_heads
+ self.block_size = self.model_args.block_size
+ self.eps = self.model_args.eps
+
+ self.ref_norm_gamma: torch.Tensor | None = None
+ self.ref_wq_a: torch.Tensor | None = None
+ self.ref_wki: torch.Tensor | None = None
+ self.ref_wis: torch.Tensor | None = None
+
+ self.tilert_norm_gamma: torch.Tensor | None = None
+ self.tilert_wqakis: torch.Tensor | None = None
+ self.tilert_wis: torch.Tensor | None = None
+
+ self.q_out: torch.Tensor | None = None
+ self.ki_out: torch.Tensor | None = None
+ self.idx_scores_out: torch.Tensor | None = None
+ self.x_rmsnorm_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_scale_out: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ if self.arch_name == "glm_5":
+ self.compute_kernel_type = "fp8mma_68cta"
+ else:
+ self.compute_kernel_type = "fp8mma"
+
+ self.tilert_tensor_alias: list[str] = [
+ "x_rmsnorm_gamma",
+ "qakis_weights",
+ "qakis_scales",
+ ]
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_norm_gamma, self.tilert_wqakis, self.tilert_wis]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat weights for device sharding."""
+ input_layernorm_weight = (
+ weights_map[self.ref_weights_alias.x_rmsnorm_gamma][None, ...]
+ .float()
+ .repeat(self.num_devices, 1)
+ )
+ q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wk_weight = weights_map[self.ref_weights_alias.wk_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wk_weight_scale = weights_map[self.ref_weights_alias.wk_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ wis_weight = weights_map[self.ref_weights_alias.wis_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ is_n_rows = weights_map[self.ref_weights_alias.wis_weights].shape[0]
+ is_scale_rows = (is_n_rows + self.block_size - 1) // self.block_size
+ is_scale_cols = self.dim // self.block_size
+ wis_weight_scale = torch.ones(
+ self.num_devices, is_scale_rows, is_scale_cols, dtype=torch.bfloat16
+ )
+ return {
+ self.tilert_weights_alias.x_rmsnorm_gamma: input_layernorm_weight,
+ self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
+ self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
+ self.tilert_weights_alias.wk_weights: wk_weight,
+ self.tilert_weights_alias.wk_scales: wk_weight_scale,
+ self.tilert_weights_alias.wis_weights: wis_weight,
+ self.tilert_weights_alias.wis_scales: wis_weight_scale,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ aliases = self.ref_weights_alias()
+ self.ref_norm_gamma = state_dict[aliases[0]]
+ self.ref_wq_a = weight_dequant(state_dict[aliases[1]], state_dict[aliases[2]])
+ self.ref_wki = weight_dequant(state_dict[aliases[3]], state_dict[aliases[4]])
+ self.ref_wis = state_dict[aliases[5]].to(torch.bfloat16)
+
+ assert self.ref_norm_gamma.shape[-1] == self.dim
+ assert self.ref_wq_a.shape == (self.q_lora_rank, self.dim)
+ assert self.ref_wki.shape == (self.idx_head_dim, self.dim)
+ assert self.ref_wis.shape == (self.idx_score_dim, self.dim)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ tilert_aliases = self.tilert_weights_alias()
+ weights_list = [state_dict[alias] for alias in tilert_aliases]
+ converter = RMSNormProjxWqakisWeightsConverter(self.model_args, self.num_devices)
+ result = converter.convert_to_decoupled(weights_list)
+ self.tilert_wqakis, self.tilert_wis, self.tilert_norm_gamma = result
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int) -> None:
+ self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
+ self.ki_out = torch.zeros((batch_size, seq_len, self.idx_head_dim), dtype=torch.bfloat16)
+ self.idx_scores_out = torch.zeros(
+ (batch_size, seq_len, self.idx_score_dim), dtype=torch.bfloat16
+ )
+ self.x_rmsnorm_out = torch.zeros((batch_size, seq_len, self.dim), dtype=torch.bfloat16)
+ self.x_rmsnorm_quant_out = torch.zeros(
+ (batch_size, seq_len, self.dim), dtype=torch.float8_e4m3fn
+ )
+ self.x_rmsnorm_quant_scale_out = torch.zeros(
+ (batch_size, seq_len, self.dim // self.block_size), dtype=torch.float32
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def init_random_weights(self) -> None:
+ bs = self.block_size
+ dim_scale_dim = self.dim // bs
+ q_scale_dim = (self.q_lora_rank + bs - 1) // bs
+ ki_scale_dim = (self.idx_head_dim + bs - 1) // bs
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+
+ tensor_list = [
+ torch.randn(self.dim, dtype=torch.float32),
+ torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(self.idx_head_dim, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(ki_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(self.idx_score_dim, self.dim, dtype=torch.bfloat16),
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def golden_forward(
+ self,
+ x: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Pure PyTorch reference: RMSNorm -> q, ki, idx_scores."""
+ assert self.ref_norm_gamma is not None
+ assert self.ref_wq_a is not None
+ assert self.ref_wki is not None
+ assert self.ref_wis is not None
+
+ x_rmsnorm = torch.nn.functional.rms_norm(
+ x.float(), [x.size(-1)], self.ref_norm_gamma, self.eps
+ )
+ q_out = torch.matmul(x_rmsnorm.float(), self.ref_wq_a.transpose(0, 1).float())
+ ki_out = torch.matmul(x_rmsnorm.float(), self.ref_wki.transpose(0, 1).float())
+ idx_scores_out = torch.matmul(x_rmsnorm.float(), self.ref_wis.transpose(0, 1).float())
+ return (
+ q_out.to(torch.bfloat16),
+ ki_out.to(torch.bfloat16),
+ idx_scores_out.to(torch.bfloat16),
+ )
+
+ def tilert_forward(
+ self,
+ x: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run RMSNorm + ProjXWqaki + ProjXWis via TileRT CUDA kernels."""
+ rmsnorm_quant(
+ x.to(torch.bfloat16),
+ self.tilert_norm_gamma,
+ self.x_rmsnorm_out,
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ projx_wqaki(
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.tilert_wqakis,
+ self.q_out,
+ self.ki_out,
+ self.profile_logs,
+ self.compute_kernel_type,
+ model_arch=self.model_args.arch_name,
+ )
+ wis_compute_kernel_type = "bf16"
+ projx_wis(
+ self.x_rmsnorm_out,
+ self.tilert_wis,
+ self.idx_scores_out,
+ wis_compute_kernel_type,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+
+ return self.q_out, self.ki_out, self.idx_scores_out
+
+ def __call__(
+ self,
+ x: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.golden_forward(x)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projx_wqkva.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projx_wqkva.py
new file mode 100644
index 0000000..5343357
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_projx_wqkva.py
@@ -0,0 +1,516 @@
+"""RMSNormProjxWqkva operation module."""
+
+from enum import Enum
+
+import torch
+
+from tilert.models.base import TileRTModule, TilertWeightsConverter
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RMSNormProjxWqkva",
+ "RMSNormProjxWqkvaAlgorithm",
+]
+
+
+class RMSNormProjQKVAFP8MMAWeightsConverter:
+ """Weight converter: pack FP8 weights into the kernel's packed layout."""
+
+ HIDDEN_DIM = 6144
+ Q_LORA_RANK = 2048
+ KV_LORA_RANK = 512
+ QK_ROPE_HEAD_DIM = 64
+ TOTAL_ROWS = Q_LORA_RANK + KV_LORA_RANK + QK_ROPE_HEAD_DIM
+ ROWS_PER_CTA = 32
+ NUM_CTAS = TOTAL_ROWS // ROWS_PER_CTA
+ COLS_PER_PAGE = 1024
+ NUM_PAGES = HIDDEN_DIM // COLS_PER_PAGE
+ SCALES_PER_PAGE = COLS_PER_PAGE // 128
+ BLOCK_SIZE = 128
+
+ MAT_BYTES = ROWS_PER_CTA * COLS_PER_PAGE
+ SCALE_OFFSET = MAT_BYTES
+ PAGE_BYTES = ((MAT_BYTES + 128 + 127) // 128) * 128
+
+ @staticmethod
+ def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle [*, 16, 32] tiles into the packed weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+
+ @staticmethod
+ def convert_to_fp8_mma_gemv(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wkv_a: torch.Tensor,
+ wkv_a_scale: torch.Tensor,
+ w_pe: torch.Tensor,
+ w_pe_scale: torch.Tensor,
+ attn_norm_weight: torch.Tensor,
+ *,
+ hidden_dim: int = 6144,
+ q_lora_rank: int = 2048,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Pack FP8 weights for the FP8 MMA kernel.
+
+ Args:
+ hidden_dim: Model hidden dimension.
+ q_lora_rank: Q projection rank.
+ """
+ C = RMSNormProjQKVAFP8MMAWeightsConverter
+ block_size = C.BLOCK_SIZE
+ kv_lora_rank = C.KV_LORA_RANK
+ qk_rope_head_dim = C.QK_ROPE_HEAD_DIM
+
+ expected = q_lora_rank * hidden_dim
+ assert wq_a.numel() == expected, f"wq_a numel {wq_a.numel()} != expected {expected}"
+ expected = kv_lora_rank * hidden_dim
+ assert wkv_a.numel() == expected, f"wkv_a numel {wkv_a.numel()} != expected {expected}"
+ expected = qk_rope_head_dim * hidden_dim
+ assert w_pe.numel() == expected, f"w_pe numel {w_pe.numel()} != expected {expected}"
+
+ total_rows = q_lora_rank + kv_lora_rank + qk_rope_head_dim
+ num_ctas = total_rows // C.ROWS_PER_CTA
+ num_pages = hidden_dim // C.COLS_PER_PAGE
+
+ wq_a_f = weight_dequant(wq_a.reshape(q_lora_rank, hidden_dim), wq_a_scale)
+ wkv_a_f = weight_dequant(wkv_a.reshape(kv_lora_rank, hidden_dim), wkv_a_scale)
+ w_pe_f = weight_dequant(w_pe.reshape(qk_rope_head_dim, hidden_dim), w_pe_scale)
+ w_float = torch.cat([wq_a_f, wkv_a_f, w_pe_f], dim=0)
+
+ w_blocks = w_float.reshape(total_rows, hidden_dim // block_size, block_size)
+ col_max = w_blocks.abs().amax(dim=(0, 2))
+ fp8_max = torch.finfo(torch.float8_e4m3fn).max
+ w_scales = (col_max / fp8_max).clamp(min=1e-12)
+
+ scales_expanded = w_scales.repeat_interleave(block_size)
+ w_scaled = w_float / scales_expanded.unsqueeze(0)
+ w_fp8 = w_scaled.to(torch.float8_e4m3fn)
+
+ assert C.MAT_BYTES == C.SCALE_OFFSET, "Layout mismatch: scales must follow mat"
+ assert block_size == C.COLS_PER_PAGE // C.SCALES_PER_PAGE, "Block size mismatch"
+ assert w_scales.numel() == num_pages * C.SCALES_PER_PAGE, "Scale count mismatch"
+
+ w_bytes = w_fp8.view(torch.uint8)
+ num_tiles = C.COLS_PER_PAGE // 32
+
+ mat = w_bytes.reshape(num_ctas, C.ROWS_PER_CTA, num_pages, C.COLS_PER_PAGE)
+ mat = mat.transpose(1, 2)
+
+ mat = mat.reshape(num_ctas, num_pages, 2, 16, num_tiles, 32)
+ mat = mat.transpose(3, 4)
+ mat = C._swizzle_mma_16x32(mat)
+ mat = mat.contiguous().reshape(num_ctas, num_pages, C.MAT_BYTES)
+
+ scales_f32 = w_scales.reshape(num_pages, C.SCALES_PER_PAGE).to(torch.float32).contiguous()
+ scales_bytes = scales_f32.view(torch.uint8)
+ scales_bytes = scales_bytes.unsqueeze(0).expand(num_ctas, -1, -1)
+
+ pad_size = C.PAGE_BYTES - C.MAT_BYTES - C.SCALES_PER_PAGE * 4
+ padding = torch.zeros(num_ctas, num_pages, pad_size, dtype=torch.uint8, device=w_fp8.device)
+
+ packed = torch.cat([mat, scales_bytes, padding], dim=-1)
+ packed = packed.contiguous().reshape(-1)
+
+ return packed.view(torch.float8_e4m3fn), attn_norm_weight.clone()
+
+
+class RMSNormProjQKVAFP16MMAWeightsConverter:
+ """Weight converter: pack FP16 weights for the kernel."""
+
+ KV_LORA_RANK = 512
+ QK_ROPE_HEAD_DIM = 64
+ ROWS_PER_CTA = 32
+ COLS_PER_PAGE = 512
+ BLOCK_SIZE = 128
+
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ """Swizzle [*, 16, 16] tiles into the packed weight layout."""
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
+ @staticmethod
+ def convert_to_fp16_mma_gemv(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wkv_a: torch.Tensor,
+ wkv_a_scale: torch.Tensor,
+ w_pe: torch.Tensor,
+ w_pe_scale: torch.Tensor,
+ attn_norm_weight: torch.Tensor,
+ *,
+ hidden_dim: int = 6144,
+ q_lora_rank: int = 2048,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Pack weights into the FP16 layout expected by the kernel."""
+ C = RMSNormProjQKVAFP16MMAWeightsConverter
+ kv_lora_rank = C.KV_LORA_RANK
+ qk_rope_head_dim = C.QK_ROPE_HEAD_DIM
+ cols_per_page = C.COLS_PER_PAGE
+ rows_per_cta = C.ROWS_PER_CTA
+
+ total_rows = q_lora_rank + kv_lora_rank + qk_rope_head_dim
+ num_ctas = total_rows // rows_per_cta
+ num_pages = hidden_dim // cols_per_page
+ num_k_tiles = cols_per_page // 16
+
+ wq_a_f = weight_dequant(wq_a.reshape(q_lora_rank, hidden_dim), wq_a_scale)
+ wkv_a_f = weight_dequant(wkv_a.reshape(kv_lora_rank, hidden_dim), wkv_a_scale)
+ w_pe_f = weight_dequant(w_pe.reshape(qk_rope_head_dim, hidden_dim), w_pe_scale)
+ w_float = torch.cat([wq_a_f, wkv_a_f, w_pe_f], dim=0)
+
+ w_fp16 = w_float.to(torch.float16)
+
+ mat = w_fp16.reshape(num_ctas, rows_per_cta, num_pages, cols_per_page)
+ mat = mat.transpose(1, 2)
+
+ mat = mat.reshape(num_ctas, num_pages, 2, 16, num_k_tiles, 16)
+ mat = mat.transpose(3, 4)
+ mat = C._swizzle_mma_16x16(mat)
+ mat = mat.contiguous()
+
+ mat_bytes = mat.view(torch.uint8).reshape(num_ctas, num_pages, -1)
+ packed = mat_bytes.contiguous().reshape(-1)
+
+ return packed.view(torch.float16), attn_norm_weight.clone()
+
+
+class RMSNormProjxWqkvaAlgorithm(Enum):
+ """RMSNormProjxWqkva algorithm."""
+
+ DECOUPLED = "decoupled"
+
+
+class RMSNormProjxWqkvaWeightsConverter(TilertWeightsConverter):
+ """Dispatch weight converter for RMSNormProjxWqkva."""
+
+ def __init__(self, model_args: ModelArgs, num_devices: int):
+ super().__init__(model_args, num_devices)
+
+ def convert_to_fp8_mma_gemv(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert tilert weights list to the FP8 kernel-ready format.
+
+ Args:
+ weights: [gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale]
+ """
+ gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale = weights
+ return RMSNormProjQKVAFP8MMAWeightsConverter.convert_to_fp8_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ gamma,
+ hidden_dim=self.model_args.dim,
+ q_lora_rank=self.model_args.q_lora_rank,
+ )
+
+ def convert_to_fp16_mma_gemv(
+ self, weights: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert tilert weights list to the FP16 kernel-ready format.
+
+ Args:
+ weights: [gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale]
+ """
+ gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, w_pe, w_pe_scale = weights
+ return RMSNormProjQKVAFP16MMAWeightsConverter.convert_to_fp16_mma_gemv(
+ wq_a,
+ wq_a_scale,
+ wkv_a,
+ wkv_a_scale,
+ w_pe,
+ w_pe_scale,
+ gamma,
+ hidden_dim=self.model_args.dim,
+ q_lora_rank=self.model_args.q_lora_rank,
+ )
+
+
+class RMSNormProjxWqkvaRefWeightsAlias:
+ """Reference weight aliases for RMSNormProjxWqkva."""
+
+ x_rmsnorm_gamma = "input_layernorm.weight"
+ q_a_weights = "self_attn.q_a_proj.weight"
+ q_a_scales = "self_attn.q_a_proj.weight_scale_inv"
+ kv_a_weights = "self_attn.kv_a_proj_with_mqa.weight"
+ kv_a_scales = "self_attn.kv_a_proj_with_mqa.weight_scale_inv"
+
+ @property
+ def ref_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.ref_tensor_alias
+
+
+class RMSNormProjxWqkvaTilertWeightsAlias:
+ """Tilert weight aliases for RMSNormProjxWqkva."""
+
+ x_rmsnorm_gamma = "x_rmsnorm_gamma"
+ q_a_weights = "q_a_weights"
+ q_a_scales = "q_a_scales"
+ kv_a_weights = "kv_a_weights"
+ kv_a_scales = "kv_a_scales"
+ w_pe_weights = "w_pe_weights"
+ w_pe_scales = "w_pe_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.x_rmsnorm_gamma,
+ self.q_a_weights,
+ self.q_a_scales,
+ self.kv_a_weights,
+ self.kv_a_scales,
+ self.w_pe_weights,
+ self.w_pe_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RMSNormProjxWqkva(TileRTModule):
+ """Fused RMSNorm + GEMV(W_q_a, W_kv_a, W_pe)."""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormProjxWqkvaAlgorithm.DECOUPLED],
+ "glm_5": [RMSNormProjxWqkvaAlgorithm.DECOUPLED],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ num_devices: int,
+ device_id: int,
+ ref_weights_alias: RMSNormProjxWqkvaRefWeightsAlias | None = None,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ num_devices=num_devices,
+ device_id=device_id,
+ )
+
+ self.tilert_weights_alias = RMSNormProjxWqkvaTilertWeightsAlias()
+ self.ref_weights_alias = (
+ ref_weights_alias
+ if ref_weights_alias is not None
+ else RMSNormProjxWqkvaRefWeightsAlias()
+ )
+
+ self.dim = self.model_args.dim
+ self.q_lora_rank = self.model_args.q_lora_rank
+ self.kv_lora_rank = self.model_args.kv_lora_rank
+ self.qk_rope_head_dim = self.model_args.qk_rope_head_dim
+ self.block_size = self.model_args.block_size
+ self.eps = self.model_args.eps
+ self.algorithm = RMSNormProjxWqkvaAlgorithm.DECOUPLED
+
+ self.ref_norm_gamma: torch.Tensor | None = None
+ self.ref_wq_a: torch.Tensor | None = None
+ self.ref_wkv_a: torch.Tensor | None = None
+ self.ref_w_pe: torch.Tensor | None = None
+
+ self.tilert_norm_gamma: torch.Tensor | None = None
+ self.tilert_wqkva: torch.Tensor | None = None
+ self.tilert_wqkva_scales = torch.zeros((1, 1), dtype=torch.bfloat16)
+
+ self.x_rmsnorm_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_out: torch.Tensor | None = None
+ self.x_rmsnorm_quant_scale_out: torch.Tensor | None = None
+
+ self.q_out: torch.Tensor | None = None
+ self.kv_out: torch.Tensor | None = None
+ self.pe_cache_out: torch.Tensor | None = None
+ self.cur_pos: torch.Tensor | None = None
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.tilert_tensor_alias: list[str] = [
+ "x_rmsnorm_gamma",
+ "qkva_weights",
+ "qkva_scales",
+ ]
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ return [self.tilert_norm_gamma, self.tilert_wqkva, self.tilert_wqkva_scales]
+
+ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Repeat weights for device sharding."""
+ input_layernorm_weight = (
+ weights_map[self.ref_weights_alias.x_rmsnorm_gamma][None, ...]
+ .float()
+ .repeat(self.num_devices, 1)
+ )
+ q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ kv_a_mqa = weights_map[self.ref_weights_alias.kv_a_weights]
+ kv_a_proj_weight = kv_a_mqa[: self.kv_lora_rank, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight = kv_a_mqa[self.kv_lora_rank :, :][None, ...].repeat(self.num_devices, 1, 1)
+ kv_a_mqa_scale = weights_map[self.ref_weights_alias.kv_a_scales]
+ kv_scale_rows = (self.kv_lora_rank + self.block_size - 1) // self.block_size
+ kv_a_proj_weight_scale = kv_a_mqa_scale[:kv_scale_rows, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ w_pe_weight_scale = kv_a_mqa_scale[kv_scale_rows:, :][None, ...].repeat(
+ self.num_devices, 1, 1
+ )
+ return {
+ self.tilert_weights_alias.x_rmsnorm_gamma: input_layernorm_weight,
+ self.tilert_weights_alias.q_a_weights: q_a_proj_weight,
+ self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale,
+ self.tilert_weights_alias.kv_a_weights: kv_a_proj_weight,
+ self.tilert_weights_alias.kv_a_scales: kv_a_proj_weight_scale,
+ self.tilert_weights_alias.w_pe_weights: w_pe_weight,
+ self.tilert_weights_alias.w_pe_scales: w_pe_weight_scale,
+ }
+
+ def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ aliases = self.ref_weights_alias()
+ self.ref_norm_gamma = state_dict[aliases[0]]
+ self.ref_wq_a = weight_dequant(state_dict[aliases[1]], state_dict[aliases[2]])
+ kv_a_mqa = weight_dequant(state_dict[aliases[3]], state_dict[aliases[4]])
+ self.ref_wkv_a = kv_a_mqa[: self.kv_lora_rank, :]
+ self.ref_w_pe = kv_a_mqa[self.kv_lora_rank :, :]
+
+ assert self.ref_norm_gamma.shape[-1] == self.dim
+ assert self.ref_wq_a.shape == (self.q_lora_rank, self.dim)
+ assert self.ref_wkv_a.shape == (self.kv_lora_rank, self.dim)
+ assert self.ref_w_pe.shape == (self.qk_rope_head_dim, self.dim)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ tilert_aliases = self.tilert_weights_alias()
+ weights_list = [state_dict[alias] for alias in tilert_aliases]
+ converter = RMSNormProjxWqkvaWeightsConverter(self.model_args, self.num_devices)
+ self.tilert_wqkva, self.tilert_norm_gamma = converter.convert_to_fp8_mma_gemv(weights_list)
+ self.tilert_wqkva_scales = torch.zeros((1,), dtype=torch.float32)
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, max_len: int = 128) -> None:
+ self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16)
+ self.kv_out = torch.zeros((batch_size, seq_len, self.kv_lora_rank), dtype=torch.bfloat16)
+ self.pe_cache_out = torch.zeros(
+ (batch_size, max_len, self.qk_rope_head_dim), dtype=torch.bfloat16
+ )
+ self.cur_pos = torch.zeros((1,), dtype=torch.int32)
+ self.x_rmsnorm_out = torch.zeros((batch_size, seq_len, self.dim), dtype=torch.bfloat16)
+ self.x_rmsnorm_quant_out = torch.zeros(
+ (batch_size, seq_len, self.dim), dtype=torch.float8_e4m3fn
+ )
+ self.x_rmsnorm_quant_scale_out = torch.zeros(
+ (batch_size, seq_len, self.dim // self.block_size), dtype=torch.float32
+ )
+ self.profile_logs = get_profile_log_tensor()
+ self.is_init = True
+
+ def init_random_weights(self) -> None:
+ bs = self.block_size
+ dim_scale_dim = self.dim // bs
+ q_scale_dim = (self.q_lora_rank + bs - 1) // bs
+ kv_mqa_rows = self.kv_lora_rank + self.qk_rope_head_dim
+ kv_mqa_scale_dim = (kv_mqa_rows + bs - 1) // bs
+ scale_dtype = torch.bfloat16
+
+ tensor_list = [
+ torch.randn(self.dim, dtype=torch.float32),
+ torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ torch.randn(kv_mqa_rows, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
+ torch.randn(kv_mqa_scale_dim, dim_scale_dim, dtype=scale_dtype),
+ ]
+ ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list))
+ self.init_reference_weights(ref_state_dict)
+ self.init_tilert_weights(
+ {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()}
+ )
+
+ def golden_forward(
+ self,
+ x: torch.Tensor,
+ cur_pos: int = 0, # noqa: U100
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Pure PyTorch reference: RMSNorm -> q, kv, pe."""
+ assert self.ref_norm_gamma is not None
+ assert self.ref_wq_a is not None
+ assert self.ref_wkv_a is not None
+ assert self.ref_w_pe is not None
+
+ x_rmsnorm = torch.nn.functional.rms_norm(
+ x.float(), [x.size(-1)], self.ref_norm_gamma, self.eps
+ )
+ q_out = torch.matmul(x_rmsnorm.float(), self.ref_wq_a.transpose(0, 1).float())
+ kv_out = torch.matmul(x_rmsnorm.float(), self.ref_wkv_a.transpose(0, 1).float())
+ pe_out = torch.matmul(x_rmsnorm.float(), self.ref_w_pe.transpose(0, 1).float())
+ return (
+ q_out.to(torch.bfloat16),
+ kv_out.to(torch.bfloat16),
+ pe_out.to(torch.bfloat16),
+ )
+
+ def tilert_forward(
+ self,
+ x: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run RMSNorm + 3-way GEMV via the TileRT CUDA kernels."""
+ assert self.cur_pos is not None
+ assert self.pe_cache_out is not None
+ self.cur_pos.fill_(cur_pos)
+
+ from tilert.models.glm_5._dsa_v32.ops.projx_wqkva import projx_wqkva as _projx_wqkva
+ from tilert.models.glm_5._dsa_v32.ops.rmsnorm_quant import rmsnorm_quant as _rmsnorm_quant
+
+ _rmsnorm_quant(
+ x.to(torch.bfloat16),
+ self.tilert_norm_gamma,
+ self.x_rmsnorm_out,
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+ _projx_wqkva(
+ self.x_rmsnorm_quant_out,
+ self.x_rmsnorm_quant_scale_out,
+ self.tilert_wqkva,
+ self.cur_pos,
+ self.q_out,
+ self.kv_out,
+ self.pe_cache_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
+
+ seq_len = x.size(-2)
+ pe_at_pos = self.pe_cache_out[:, cur_pos : cur_pos + seq_len, :]
+ return self.q_out, self.kv_out, pe_at_pos
+
+ def __call__(
+ self,
+ x: torch.Tensor,
+ cur_pos: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.golden_forward(x, cur_pos)
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_quant.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_quant.py
new file mode 100644
index 0000000..1d399c5
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_quant.py
@@ -0,0 +1,64 @@
+"""RMSNormQuant operation module."""
+
+from __future__ import annotations
+
+import torch
+
+__all__ = [
+ "BLOCK_SIZE",
+ "DIM_DEEPSEEK_V3_2",
+ "DIM_GLM_5",
+ "rmsnorm_quant",
+]
+
+BLOCK_SIZE = 128
+DIM_DEEPSEEK_V3_2 = 7168
+DIM_GLM_5 = 6144
+
+
+def rmsnorm_quant(
+ hidden_in: torch.Tensor,
+ gamma_in: torch.Tensor,
+ hidden_out: torch.Tensor,
+ quant_hidden_out: torch.Tensor | None = None,
+ quant_hidden_scale_out: torch.Tensor | None = None,
+ profile_logs: torch.Tensor | None = None,
+ compute_kernel_type: str = "general",
+ *,
+ model_arch: str,
+) -> None:
+ """
+ Rmsnorm with optional activation quantization.
+
+ Args:
+ hidden_in: Input tensor (..., dim).
+ gamma_in: RMSNorm gamma (dim,).
+ hidden_out: RMSNorm output (..., dim).
+ quant_hidden_out: Optional quantized output (..., dim). If None, no quant.
+ quant_hidden_scale_out: Optional quant scale (..., dim // block_size). If None, no quant.
+ profile_logs: Optional profile logs tensor.
+ """
+ if profile_logs is None:
+ raise ValueError("profile_logs is required when calling rmsnorm_quant.")
+
+ if quant_hidden_out is None or quant_hidden_scale_out is None:
+ torch.ops.tilert.rmsnorm_op(
+ hidden_in,
+ gamma_in,
+ hidden_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+ else:
+ torch.ops.tilert.rmsnorm_quant_op(
+ hidden_in,
+ gamma_in,
+ hidden_out,
+ quant_hidden_out,
+ quant_hidden_scale_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ torch.empty(0, dtype=torch.int64, device=hidden_in.device),
+ )
diff --git a/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_up_gate_silu.py b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_up_gate_silu.py
new file mode 100644
index 0000000..25adae9
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/rmsnorm_up_gate_silu.py
@@ -0,0 +1,363 @@
+"""RMSNormUpGateSiLU operation module."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+import torch.nn.functional as F
+
+from tilert.models.base import TileRTModule
+from tilert.models.common import weight_dequant
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.ops.expert_sel_up_gate_silu import (
+ ExpertSelectUpGateSiLU,
+ ExpertSelectUpGateSiLUWeightsConverter,
+)
+from tilert.utils import get_profile_log_tensor
+
+__all__ = [
+ "RMSNormUpGateSiLUAlgorithm",
+ "RMSNormUpGateSiLU",
+ "RMSNormUpGateSiLUTilertWeightsAlias",
+ "rmsnorm_up_gate_silu",
+]
+
+
+def rmsnorm_up_gate_silu(
+ hidden_in: torch.Tensor,
+ gamma_in: torch.Tensor,
+ weights_in: torch.Tensor,
+ hidden_out: torch.Tensor,
+ profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "fp8mma",
+) -> None:
+ """rmsnorm_up_gate_silu operation."""
+ torch.ops.tilert.rmsnorm_up_gate_silu_op(
+ hidden_in,
+ gamma_in,
+ weights_in,
+ hidden_out,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
+
+
+class RMSNormUpGateSiLUAlgorithm(Enum):
+ """RMSNormUpGateSiLU algorithm"""
+
+ FP8MMA = "fp8mma"
+ FP16MMA = "fp16mma"
+
+
+RMSNormUpGateSiLUWeightsConverter = ExpertSelectUpGateSiLUWeightsConverter
+ExpertSelectUpGateSiLUW = ExpertSelectUpGateSiLUWeightsConverter
+
+
+@dataclass
+class RMSNormUpGateSiLUTilertWeightsAlias:
+ """TileRT weights alias for RMSNormUpGateSiLU."""
+
+ unproj_o_gamma = "unproj_o_gamma"
+ gate_weights = "gate_weights"
+ gate_scales = "gate_scales"
+ up_weights = "up_weights"
+ up_scales = "up_scales"
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return [
+ self.unproj_o_gamma,
+ self.gate_weights,
+ self.gate_scales,
+ self.up_weights,
+ self.up_scales,
+ ]
+
+ def __call__(self) -> list[str]:
+ return self.tilert_tensor_alias
+
+
+class RMSNormUpGateSiLU(TileRTModule):
+ """RMSNormUpGateSiLU module"""
+
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RMSNormUpGateSiLUAlgorithm.FP8MMA, RMSNormUpGateSiLUAlgorithm.FP16MMA],
+ "glm_5": [RMSNormUpGateSiLUAlgorithm.FP8MMA, RMSNormUpGateSiLUAlgorithm.FP16MMA],
+ }
+
+ def __init__(
+ self,
+ model_args: ModelArgs,
+ device_id: int,
+ num_devices: int,
+ algorithm: RMSNormUpGateSiLUAlgorithm = RMSNormUpGateSiLUAlgorithm.FP8MMA,
+ ):
+ super().__init__(
+ self.__class__.__name__,
+ model_args=model_args,
+ device_id=device_id,
+ num_devices=num_devices,
+ )
+
+ self.arch_name = self.model_args.arch_name
+ self.dim = self.model_args.dim
+
+ self.inter_dim = self.model_args.inter_dim
+ self.moe_inter_dim = self.model_args.moe_inter_dim
+ self.moe_inter_dim_per_device = self.moe_inter_dim // self.num_devices
+ self.inter_dim_per_device = self.inter_dim // self.num_devices
+ self.n_experts = self.inter_dim_per_device // self.moe_inter_dim_per_device
+ self.eps = self.model_args.eps
+
+ self.block_size = self.model_args.block_size
+ self.algorithm = algorithm
+
+ self.ref_norm_gamma: torch.Tensor | None = None
+ self.ref_gate: torch.Tensor | None = None
+ self.ref_up: torch.Tensor | None = None
+
+ self.tilert_norm_gamma: torch.Tensor | None = None
+ self.tilert_weights: torch.Tensor | None = None
+ self.tilert_scales = torch.zeros(
+ 9, 4, 64, dtype=torch.bfloat16, device=torch.device("cuda")
+ )
+
+ self.hidden_out: torch.Tensor | None = None
+
+ self.profile_logs: torch.Tensor | None = None
+ self.is_init = False
+
+ self.rmsnorm_up_gate_silu_func = rmsnorm_up_gate_silu
+
+ self.tilert_weights_alias = RMSNormUpGateSiLUTilertWeightsAlias()
+
+ self.ref_tensor_alias: list[str] = [
+ "post_attention_layernorm.weight",
+ "mlp.gate_proj.weight",
+ "mlp.gate_proj.weight_scale_inv",
+ "mlp.up_proj.weight",
+ "mlp.up_proj.weight_scale_inv",
+ ]
+
+ @property
+ def tilert_tensor_alias(self) -> list[str]:
+ return self.tilert_weights_alias()
+
+ def get_weights_list(self) -> list[torch.Tensor]:
+ """
+ Get the weights list.
+
+ Returns:
+ List of weights.
+ """
+ return [self.tilert_norm_gamma, self.tilert_weights, self.tilert_scales]
+
+ def device_sharding(
+ self,
+ weights_dict: dict[str, torch.Tensor],
+ key_prefix: str,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Device sharding.
+
+ Args:
+ weights_dict: Dictionary of weights.
+
+ Returns:
+ Tuple of weights.
+ """
+ rmsnorm_gamma_key = f"{key_prefix}.post_attention_layernorm.weight"
+ if ".mlp" in key_prefix:
+ key_prefix_without_mlp = key_prefix.replace(".mlp", "")
+ rmsnorm_gamma_key = f"{key_prefix_without_mlp}.post_attention_layernorm.weight"
+ elif key_prefix == "mlp":
+ rmsnorm_gamma_key = "post_attention_layernorm.weight"
+ rmsnorm_gamma = weights_dict[rmsnorm_gamma_key]
+ rmsnorm_gamma = rmsnorm_gamma[None, :].repeat(self.num_devices, 1)
+
+ gate_weights, gate_scales, up_weights, up_scales = (
+ ExpertSelectUpGateSiLU.process_gate_up_weights(
+ key_prefix,
+ weights_dict,
+ self.num_devices,
+ )
+ )
+ gate_weights = gate_weights.reshape(self.n_experts, self.num_devices, -1, self.dim)
+ gate_weights = gate_weights.transpose(0, 1)
+ gate_scales = gate_scales.reshape(
+ self.n_experts, self.num_devices, -1, self.dim // self.block_size
+ )
+ gate_scales = gate_scales.transpose(0, 1)
+ up_weights = up_weights.reshape(self.n_experts, self.num_devices, -1, self.dim)
+ up_weights = up_weights.transpose(0, 1)
+ up_scales = up_scales.reshape(
+ self.n_experts, self.num_devices, -1, self.dim // self.block_size
+ )
+ up_scales = up_scales.transpose(0, 1)
+ return (
+ rmsnorm_gamma.contiguous(),
+ gate_weights.contiguous(),
+ gate_scales.contiguous(),
+ up_weights.contiguous(),
+ up_scales.contiguous(),
+ )
+
+ def init_reference_weights(
+ self,
+ state_dict: dict[str, torch.Tensor],
+ key_prefix: str,
+ device_id: int = 0,
+ ) -> None:
+ """
+ Initialize the reference weights.
+
+ Args:
+ state_dict: State dictionary.
+ device_id: Device ID.
+ """
+ sharded_list = self.device_sharding(state_dict, key_prefix)
+
+ gamma = sharded_list[0][device_id]
+ gate_weights = sharded_list[1][device_id]
+ gate_scales = sharded_list[2][device_id]
+ up_weights = sharded_list[3][device_id]
+ up_scales = sharded_list[4][device_id]
+ self.ref_norm_gamma = gamma
+ ref_gate_list = [
+ weight_dequant(gate_weights, gate_scales)
+ for gate_weights, gate_scales in zip(gate_weights, gate_scales)
+ ]
+ ref_up_list = [
+ weight_dequant(up_weights, up_scales)
+ for up_weights, up_scales in zip(up_weights, up_scales)
+ ]
+ self.ref_gate = torch.stack(ref_gate_list, dim=0)
+ self.ref_up = torch.stack(ref_up_list, dim=0)
+
+ def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+ """
+ Initialize the tilert weights.
+
+ Args:
+ state_dict: State dictionary.
+ """
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.tilert_norm_gamma, self.tilert_weights = RMSNormUpGateSiLUWeightsConverter(
+ self.model_args, self.num_devices
+ ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tilert_weights_alias()])
+
+ def init_tilert_vars(self, batch_size: int, seq_len: int, dev_id: int = 0) -> None:
+ """
+ Initialize the tilert variables.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ """
+ self.hidden_out = torch.zeros(
+ (
+ batch_size,
+ seq_len,
+ self.n_experts,
+ self.moe_inter_dim_per_device,
+ ),
+ dtype=torch.bfloat16,
+ device=f"cuda:{dev_id}",
+ )
+
+ self.profile_logs = get_profile_log_tensor(device=f"cuda:{dev_id}")
+ self.is_init = True
+
+ def init_random_weights(self, dev_id: int = 0) -> None:
+ """
+ Initialize the random weights.
+
+ Returns:
+ None
+ """
+ gamma = torch.randn(self.dim, dtype=torch.float32, device=f"cuda:{dev_id}")
+ gate_weights = torch.randn(
+ self.inter_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{dev_id}"
+ ).to(torch.float8_e4m3fn)
+ up_weights = torch.randn(
+ self.inter_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{dev_id}"
+ ).to(torch.float8_e4m3fn)
+ inter_dim_scale_dim = self.inter_dim // self.block_size
+ dim_scale_dim = self.dim // self.block_size
+ scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16
+ gate_scales = torch.randn(
+ inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=f"cuda:{dev_id}"
+ )
+ up_scales = torch.randn(
+ inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=f"cuda:{dev_id}"
+ )
+ tensor_list = [
+ gamma,
+ gate_weights,
+ gate_scales,
+ up_weights,
+ up_scales,
+ ]
+ state_dict = dict(zip(self.ref_tensor_alias, tensor_list))
+ self.init_reference_weights(state_dict, "mlp", dev_id)
+ sharded_list = self.device_sharding(state_dict, "mlp")
+ sharded_state_dict = {
+ alias: sharded_list[i][dev_id] for i, alias in enumerate(self.tilert_weights_alias())
+ }
+ self.init_tilert_weights(sharded_state_dict)
+
+ def golden_forward(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ assert self.ref_gate is not None
+ assert self.ref_up is not None
+ bsz = x_in.shape[0]
+ seq_len = x_in.shape[1]
+ assert bsz == 1
+ x_in_rmsnorm = torch.nn.functional.rms_norm(
+ x_in.float(), [x_in.size(-1)], self.ref_norm_gamma, self.eps
+ )
+ hidden_out_list = []
+ for s in range(seq_len):
+ hidden_out_w1_list = []
+ hidden_out_w3_list = []
+
+ for i in range(self.n_experts):
+ hidden_out_w1_sel = x_in_rmsnorm[0, s].float() @ self.ref_gate[i].float().T
+ hidden_out_w3_sel = x_in_rmsnorm[0, s].float() @ self.ref_up[i].float().T
+ hidden_out_w1_list.append(hidden_out_w1_sel)
+ hidden_out_w3_list.append(hidden_out_w3_sel)
+ hidden_out_w1 = torch.stack(hidden_out_w1_list, dim=0)
+ hidden_out_w3 = torch.stack(hidden_out_w3_list, dim=0)
+ hidden_out = F.silu(hidden_out_w1.float()) * hidden_out_w3.float()
+ hidden_out = hidden_out.to(torch.bfloat16)
+ hidden_out_list.append(hidden_out)
+ hidden_out = torch.stack(hidden_out_list, dim=0)
+ hidden_out = hidden_out[None, ...]
+ return hidden_out
+
+ def tilert_forward(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ assert self.rmsnorm_up_gate_silu_func is not None
+ assert self.algorithm is not None, "Algorithm is not set"
+ self.rmsnorm_up_gate_silu_func(
+ x_in,
+ self.tilert_norm_gamma,
+ self.tilert_weights,
+ self.hidden_out,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ compute_kernel_type=self.algorithm.value,
+ )
+ return self.hidden_out
+
+ def __call__(
+ self,
+ x_in: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.golden_forward(x_in)
diff --git a/python/models/deepseek_v3_2/ops/rotate.py b/tilert/models/glm_5/_dsa_v32/ops/rotate.py
similarity index 79%
rename from python/models/deepseek_v3_2/ops/rotate.py
rename to tilert/models/glm_5/_dsa_v32/ops/rotate.py
index 539f334..10a46f1 100644
--- a/python/models/deepseek_v3_2/ops/rotate.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/rotate.py
@@ -1,12 +1,14 @@
+"""Rotate(hadamard transform) operation module."""
+
from dataclasses import dataclass
+from enum import Enum
import torch
import torch.nn.functional as F
from tilert.models.base import TileRTModule
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
from tilert.models.utils import apply_rotary_emb
-from tilert.profiler.utils import parse_profile_log_tensor
from tilert.utils import get_profile_log_tensor
try:
@@ -61,45 +63,31 @@ def rotate(
output_raw: torch.Tensor,
freqs_cis_raw: torch.Tensor,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> None:
"""
Rotate (hadamard transform) operation.
- Unified for deepseek_v3_2 (64 heads) and glm_5 (32 heads). Dispatches by
- input_raw.shape[-2]: 64 -> rotate_op, 32 -> rotate_glm5_op.
-
Args:
input_raw (torch.Tensor): The input tensor [..., head, 128].
output_raw (torch.Tensor): The output tensor where the result will be stored.
freqs_cis_raw (torch.Tensor): The frequency tensor.
profile_logs (torch.Tensor): Tensor for storing profiling logs.
+ model_arch: Model architecture string.
+ compute_kernel_type: Compute kernel type string.
Returns:
None
"""
- if input_raw.dtype != torch.bfloat16:
- raise ValueError("input must be a bfloat16 tensor.")
-
- if output_raw.dtype != torch.bfloat16:
- raise ValueError("output must be a bfloat16 tensor.")
-
- if freqs_cis_raw.dtype != torch.float32:
- raise ValueError("freqs_cis must be a float32 tensor.")
-
- head = input_raw.shape[-2]
- dim = input_raw.shape[-1]
- if dim != 128:
- raise ValueError("dim must be 128, as we precompute scale inner kernel")
-
- if head == 64:
- torch.ops.tilert.rotate_op(input_raw, output_raw, freqs_cis_raw, profile_logs)
- elif head == 32:
- torch.ops.tilert.rotate_glm5_op(input_raw, output_raw, freqs_cis_raw, profile_logs)
- else:
- raise ValueError(
- f"Unsupported head size: {head}. Rotate op supports "
- "index_n_heads=64 (deepseek_v3_2) or 32 (glm_5)."
- )
+ torch.ops.tilert.rotate_op(
+ input_raw,
+ output_raw,
+ freqs_cis_raw,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
+ )
@dataclass
@@ -126,6 +114,12 @@ def __call__(self) -> list[str]:
return self.tilert_tensor_alias
+class RotateAlgorithm(Enum):
+ """Rotate algorithm."""
+
+ GENERAL = "general"
+
+
class Rotate(TileRTModule):
"""Rotate module: RoPE on first qk_rope_head_dim dims + hadamard transform.
@@ -133,6 +127,11 @@ class Rotate(TileRTModule):
No weights; uses model_args for dimensions.
"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [RotateAlgorithm.GENERAL],
+ "glm_5": [RotateAlgorithm.GENERAL],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -202,9 +201,11 @@ def tilert_forward(self, idx_q: torch.Tensor, freqs_cis: torch.Tensor) -> torch.
assert self.output is not None
assert self.profile_logs is not None
freqs_cis_real = torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1)
- rotate(idx_q, self.output, freqs_cis_real, self.profile_logs)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
+ rotate(
+ idx_q,
+ self.output,
+ freqs_cis_real,
+ self.profile_logs,
+ model_arch=self.model_args.arch_name,
+ )
return self.output
diff --git a/tilert/models/glm_5/_dsa_v32/ops/sparse_index.py b/tilert/models/glm_5/_dsa_v32/ops/sparse_index.py
new file mode 100644
index 0000000..ca69c49
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/ops/sparse_index.py
@@ -0,0 +1,135 @@
+"""Sparse index operation module."""
+
+import torch
+
+__all__ = [
+ "sparse_index",
+ "sparse_index_topk",
+]
+
+
+def sparse_index(
+ q: torch.Tensor, # noqa: VNE001
+ kv: torch.Tensor,
+ weights: torch.Tensor,
+ logits: torch.Tensor,
+ cur_pos: int,
+ profile_logs: torch.Tensor,
+ compute_kernel_type: str = "bf16",
+ *,
+ model_arch: str,
+) -> None:
+ """
+ Sparse index operation.
+
+ Calculate sparse index using q * kv * weights.
+
+ Args:
+ q (torch.Tensor): The query tensor.
+ kv (torch.Tensor): The key-value tensor.
+ weights (torch.Tensor): The weights tensor.
+ logits (torch.Tensor): The logits tensor.
+ cur_pos (int): The position of the first token.
+ profile_logs (torch.Tensor): Tensor for storing profiling logs.
+ compute_kernel_type (str): Kernel type ("bf16").
+ model_arch (str): Model architecture ("deepseek_v3_2").
+
+ Returns:
+ None
+ """
+ if q.dtype != torch.bfloat16:
+ raise ValueError("input must be a bfloat16 tensor.")
+ if kv.dtype != torch.bfloat16:
+ raise ValueError("kv must be a bfloat16 tensor.")
+ if weights.dtype != torch.bfloat16:
+ raise ValueError("weights must be a bfloat16 tensor.")
+ if logits.dtype != torch.float32:
+ raise ValueError("logits must be a float32 tensor.")
+
+ head = q.shape[-2]
+ dim = q.shape[-1]
+
+ if head != 64 and head != 32:
+ raise ValueError(
+ f"Unsupported head size: {head}. Sparse index op currently only \
+ supports a head number of 64 or 32."
+ )
+ if dim != 128:
+ raise ValueError("dim must be 128, as we precompute scale inner kernel")
+
+ device = q.device
+ if any(t.device != device for t in (kv, weights, logits, profile_logs)):
+ raise ValueError(
+ "sparse_index inputs must be on the same device: "
+ f"q={device}, kv={kv.device}, weights={weights.device}, "
+ f"logits={logits.device}, profile_logs={profile_logs.device}"
+ )
+ if model_arch == "deepseek_v3_2" and head == 32:
+ model_arch = "glm_5"
+ torch.ops.tilert.sparse_index_op(
+ q, kv, weights, logits, cur_pos, model_arch, compute_kernel_type, profile_logs
+ )
+
+
+def sparse_index_topk(
+ q: torch.Tensor, # noqa: VNE001
+ kv: torch.Tensor,
+ weights: torch.Tensor,
+ logits: torch.Tensor,
+ indices: torch.Tensor,
+ cur_pos: int,
+ profile_logs: torch.Tensor,
+) -> None:
+ """
+ Sparse index operation.
+
+ Calculate sparse index using q * kv * weights.
+
+ Args:
+ q (torch.Tensor): The query tensor.
+ kv (torch.Tensor): The key-value tensor.
+ weights (torch.Tensor): The weights tensor.
+ logits (torch.Tensor): The logits tensor.
+ cur_pos (int): The position of the first token.
+ profile_logs (torch.Tensor): Tensor for storing profiling logs.
+
+ Returns:
+ None
+ """
+ if q.dtype != torch.bfloat16:
+ raise ValueError("input must be a bfloat16 tensor.")
+ if kv.dtype != torch.bfloat16:
+ raise ValueError("kv must be a bfloat16 tensor.")
+ if weights.dtype != torch.bfloat16:
+ raise ValueError("weights must be a bfloat16 tensor.")
+ if logits.dtype != torch.float32:
+ raise ValueError("logits must be a float32 tensor.")
+
+ seqlen = q.shape[-3]
+ head = q.shape[-2]
+ dim = q.shape[-1]
+
+ if head not in (32, 64):
+ raise ValueError(
+ f"Unsupported head size: {head}. Sparse index topk fused op "
+ "supports head number of 32 (GLM5) or 64 (DSV3.2)."
+ )
+ if dim != 128:
+ raise ValueError("dim must be 128, as we precompute scale inner kernel")
+
+ device = q.device
+ if any(t.device != device for t in (kv, weights, logits, indices, profile_logs)):
+ raise ValueError(
+ "sparse_index inputs must be on the same device: "
+ f"q={device}, kv={kv.device}, weights={weights.device}, "
+ f"logits={logits.device}, profile_logs={profile_logs.device}"
+ )
+ workspace = torch.zeros(seqlen, (200 * 1024 + 260), dtype=torch.int32, device=device)
+ if head == 64:
+ torch.ops.tilert.sparse_index_topk_dsv32_op(
+ q, kv, weights, logits, cur_pos, indices, workspace, profile_logs
+ )
+ else:
+ torch.ops.tilert.sparse_index_topk_glm5_op(
+ q, kv, weights, logits, cur_pos, indices, workspace, profile_logs
+ )
diff --git a/python/models/deepseek_v3_2/ops/topk.py b/tilert/models/glm_5/_dsa_v32/ops/topk.py
similarity index 75%
rename from python/models/deepseek_v3_2/ops/topk.py
rename to tilert/models/glm_5/_dsa_v32/ops/topk.py
index bb41575..bb9dfbb 100644
--- a/python/models/deepseek_v3_2/ops/topk.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/topk.py
@@ -1,10 +1,18 @@
"""topk operations module."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
import torch
import torch.nn as nn
from tilert.utils import get_profile_log_tensor
+if TYPE_CHECKING:
+ from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+
+
__all__ = [
"TopK",
"topk_approximate",
@@ -17,6 +25,8 @@ def topk_approximate(
seq_len: int,
topk: int,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> torch.Tensor:
"""
Topk approximate operation.
@@ -42,7 +52,9 @@ def topk_approximate(
raise ValueError("batch must be 1 in this version")
indices = torch.zeros(batch, topk, dtype=torch.int32, device=logits.device)
- torch.ops.tilert.topk_approximate_op(logits, indices, seq_len, profile_logs)
+ torch.ops.tilert.topk_approximate_op(
+ logits, indices, seq_len, model_arch, compute_kernel_type, profile_logs
+ )
return indices
@@ -52,6 +64,8 @@ def topk_accurate(
seq_len: int,
topk: int,
profile_logs: torch.Tensor,
+ model_arch: str,
+ compute_kernel_type: str = "general",
) -> torch.Tensor:
"""
Topk approximate operation.
@@ -71,8 +85,8 @@ def topk_accurate(
if logits.dtype != torch.float32:
raise ValueError("logits must be a float32 tensor.")
- if topk != 2048:
- raise ValueError("topk must be 2048.")
+ if topk not in (512, 2048):
+ raise ValueError("topk must be 512 or 2048.")
assert logits.shape[0] == 1, "batch must be 1 in this version"
num_samples = logits.shape[1]
@@ -80,7 +94,13 @@ def topk_accurate(
indices = torch.zeros(num_samples, topk, dtype=torch.int32, device=logits.device)
indices_ws = torch.zeros(1, num_samples, 4, topk * 2, dtype=torch.int32, device=logits.device)
torch.ops.tilert.topk_accurate_op(
- logits, indices, seq_len - num_samples, indices_ws, profile_logs
+ logits,
+ indices,
+ seq_len - num_samples,
+ indices_ws,
+ model_arch,
+ compute_kernel_type,
+ profile_logs,
)
return indices
@@ -93,9 +113,14 @@ class TopK(nn.Module):
(reference implementation) and tilert_forward (TileRT kernel).
"""
- def __init__(self, use_approximate: bool = False) -> None:
+ def __init__(self, use_approximate: bool = False, model_args: ModelArgs | None = None) -> None:
super().__init__()
self.use_approximate = use_approximate
+ if model_args is None:
+ from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+
+ model_args = ModelArgs()
+ self.model_args = model_args
def golden_forward(
self,
@@ -131,9 +156,13 @@ def tilert_forward(
profile_logs = get_profile_log_tensor(device=logits.device)
cache_len = logits.shape[-1]
if self.use_approximate:
- indices = topk_approximate(logits, cache_len, topk, profile_logs)
+ indices = topk_approximate(
+ logits, cache_len, topk, profile_logs, model_arch=self.model_args.arch_name
+ )
else:
- indices = topk_accurate(logits, cache_len, topk, profile_logs)
+ indices = topk_accurate(
+ logits, cache_len, topk, profile_logs, model_arch=self.model_args.arch_name
+ )
if indices.dim() == 2:
return indices.unsqueeze(0)
return indices
diff --git a/python/models/deepseek_v3_2/ops/unproj_o_allreduce.py b/tilert/models/glm_5/_dsa_v32/ops/unproj_o_allreduce.py
similarity index 56%
rename from python/models/deepseek_v3_2/ops/unproj_o_allreduce.py
rename to tilert/models/glm_5/_dsa_v32/ops/unproj_o_allreduce.py
index 50b413f..257acf5 100644
--- a/python/models/deepseek_v3_2/ops/unproj_o_allreduce.py
+++ b/tilert/models/glm_5/_dsa_v32/ops/unproj_o_allreduce.py
@@ -1,5 +1,6 @@
"""UnprojOAllreduce operation module."""
+import math
from dataclasses import dataclass
from enum import Enum
@@ -7,13 +8,13 @@
from tilert.models.base import TileRTModule, TilertWeightsConverter
from tilert.models.common import weight_dequant
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.profiler.utils import parse_profile_log_tensor
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
from tilert.utils import get_profile_log_tensor
__all__ = [
"unproj_o_allreduce",
"UnProjOAllReduce",
+ "UnProjOAllReduceAlgorithm",
"UnProjOAllReduceRefWeightsAlias",
"UnProjOAllReduceTilertWeightsAlias",
]
@@ -27,7 +28,8 @@ def unproj_o_allreduce(
flag: int,
vec_out: torch.Tensor,
profile_logs: torch.Tensor,
- algorithm: str = "fp8mma",
+ model_arch: str,
+ compute_kernel_type: str = "bf16",
) -> None:
"""
Fused operation of unprojection and allreduce.
@@ -39,29 +41,26 @@ def unproj_o_allreduce(
x_in: Input tensor.
flag: Input flag.
vec_out: Output tensor.
- profile_logs: Profile logs tensor. This is a 1D tensor of shape
- (num_sms,) to store the profile logs of the unproj_o_allreduce
- operation, where num_sms is the number of SMs on the
- device.
+ profile_logs: Profile logs tensor.
+ model_arch: Model architecture ("deepseek_v3_2" or "glm_5").
+ compute_kernel_type: Compute kernel type ("bf16", "fp16mma").
"""
- if vec_out.shape[-1] == 7168:
- assert algorithm == "fp8mma", "Only fp8mma is supported for deepseek_v3_2"
- torch.ops.tilert.unproj_o_allreduce_op(
- vec_in, mat_in, mat_scale, x_in, flag, vec_out, profile_logs
- )
-
- elif vec_out.shape[-1] == 6144:
- torch.ops.tilert.unproj_o_allreduce_glm5_op(
- vec_in, mat_in, mat_scale, x_in, flag, vec_out, profile_logs, algorithm
- )
- else:
- raise ValueError(f"Unsupported vector dimension: {vec_out.shape[-1]}")
+ torch.ops.tilert.unproj_o_allreduce_op(
+ vec_in,
+ mat_in,
+ mat_scale,
+ x_in,
+ flag,
+ vec_out,
+ profile_logs,
+ model_arch,
+ compute_kernel_type,
+ )
class UnProjOAllReduceAlgorithm(Enum):
"""UnprojOAllReduce algorithm"""
- FP8MMA = "fp8mma"
FP16MMA = "fp16mma"
@@ -99,83 +98,102 @@ class UnProjOAllReduceWeightsConverter(TilertWeightsConverter):
"""UnProjOAllReduce weights converter"""
@staticmethod
- def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
- assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
- assert mat_in.dtype == torch.float8_e4m3fn
- # PTX isa fig.88
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
pre_shape = mat_in.shape[:-2]
- mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
- return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
- def convert_to_fp8mma(
- self, weights_list: list[torch.Tensor]
+ def convert_to_fp16mma_128cta(
+ self,
+ weights_list: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert the weights to fp8mma format.
+ """Convert weights to the packed kernel layout (GLM5 or DSV3.2)."""
+ with torch.inference_mode():
+ mat, scales = weights_list
+ if scales.dtype != torch.float32:
+ scales = scales.to(torch.float32)
- Args:
- weights_list: List of weights.
+ dim = self.model_args.dim
+ block_size = self.model_args.block_size
+ sms = 128
+ vec_dim = mat.shape[-1]
+ dim_per_sm = dim // sms
+ full_tiles = dim_per_sm // 16
+ remainder_rows = dim_per_sm % 16
+ stages = vec_dim // 512
+ vec_scale_dim = vec_dim // block_size
+ scale_per_stage = vec_scale_dim // stages
+
+ dim_scale_dim = dim // block_size
+ scales_per_full_tile = 2 if remainder_rows > 0 else 1
+ rem_scales = 1 if remainder_rows > 0 else 0
+ total_scale_slots = (full_tiles * scales_per_full_tile + rem_scales) * scale_per_stage
+ repeat_factor = 8 if remainder_rows == 0 else 16
+
+ sc = scales.reshape(dim_scale_dim, 1, vec_scale_dim)
+ sc = sc.repeat(1, repeat_factor, 1)
+ scales_per_cta = full_tiles * scales_per_full_tile + rem_scales
+ sc = (
+ sc.reshape(sms, scales_per_cta, stages, scale_per_stage)
+ .transpose(1, 2)
+ .reshape(sms, stages, total_scale_slots)
+ .view(torch.float8_e4m3fn)
+ )
+ sc_packed = sc
- Returns:
- Tuple of weights.
- """
- args = self.model_args
- assert args.arch_name == "deepseek_v3_2" or args.arch_name == "glm_5"
- arch_name = args.arch_name
- dim = args.dim
- num_sms = 128
- if arch_name == "deepseek_v3_2":
- num_sms = 112
- dim_per_sm = dim // num_sms
- dim_scale_dim = dim // args.block_size
+ mat_per_sm = mat.reshape(sms, dim_per_sm, vec_dim)
- with torch.inference_mode():
- mat_in, scales_trt = weights_list
- vec_dim = mat_in.shape[-1] # 2048 for both deepseek_v3_2 and glm_5
- assert scales_trt.shape == (dim // args.block_size, vec_dim // args.block_size)
+ full_rows = full_tiles * 16
+ mat_full = (
+ mat_per_sm[:, :full_rows, :]
+ .reshape(sms, full_tiles, 16, stages, 512)
+ .transpose(2, 3)
+ .reshape(sms, full_tiles, stages, 16, 32, 16)
+ .transpose(3, 4)
+ .reshape(sms, full_tiles, stages, 32, 16, 16)
+ )
+ mat_full = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat_full)
+ mat_full = mat_full.transpose(1, 2).reshape(sms, stages, -1)
- weights_trt = mat_in.reshape(num_sms, dim_per_sm, vec_dim)
- # dim_per_stage is 512
- stages = vec_dim // 512
- weights_trt = weights_trt.reshape(num_sms, dim_per_sm, stages, 512).transpose(1, 2)
-
- weights_trt = weights_trt.reshape(
- num_sms, stages, dim_per_sm // 16, 16, 16, 32
- ).transpose(-2, -3)
- weights_trt = self._swizzle_qmma_16x32(weights_trt)
- weights_trt = weights_trt.reshape(num_sms, stages, -1)
-
- if arch_name == "glm_5":
- if scales_trt.dtype != torch.float32:
- print(
- "Warning: UnProjOAllReduceWeightsConverter: "
- + f"scales_trt.dtype: {scales_trt.dtype} "
- + "is not float32, convert to float32."
- )
- scales_trt = scales_trt.to(torch.float32)
- # repeat 8 times
- scales_trt = (
- scales_trt.reshape((dim_scale_dim, 1, -1)).repeat(1, 8, 1).reshape(num_sms, -1)
+ if remainder_rows > 0:
+ mat_rem_raw = mat_per_sm[:, full_rows:, :]
+ mat_rem_padded = torch.zeros(
+ sms, 16, vec_dim, dtype=mat_rem_raw.dtype, device=mat_rem_raw.device
)
- else: # DS v3.2, use bfloat16 for scales
- scales_trt = scales_trt.to(torch.bfloat16)
-
- return weights_trt.contiguous(), scales_trt.contiguous()
+ mat_rem_padded[:, :remainder_rows, :] = mat_rem_raw
+ mat_rem = (
+ mat_rem_padded.reshape(sms, 1, 16, stages, 512)
+ .transpose(2, 3)
+ .reshape(sms, 1, stages, 16, 32, 16)
+ .transpose(3, 4)
+ .reshape(sms, 1, stages, 32, 16, 16)
+ )
+ mat_rem = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat_rem)
+ mat_rem = mat_rem.transpose(1, 2).reshape(sms, stages, -1)
+ mat_combined = torch.cat([mat_full, mat_rem], dim=-1)
+ else:
+ mat_combined = mat_full
- @staticmethod
- def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
- assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
- # PTX isa fig.88
- pre_shape = mat_in.shape[:-2]
- mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
- return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+ scales_padding = torch.zeros(
+ sms,
+ stages,
+ 128 - sc_packed.shape[-1],
+ dtype=torch.float8_e4m3fn,
+ device=mat.device,
+ )
+ mat_all = torch.cat([mat_combined, sc_packed, scales_padding], dim=-1).contiguous()
+ dummy_scales = torch.zeros(1, dtype=torch.float32, device=mat.device)
+ return mat_all, dummy_scales
def convert_to_fp16mma(
self,
weights_list: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert common weights to TileRT FP16 MMA layout."""
- assert self.model_args.arch_name == "glm_5", "Only GLM-5 supports FP16 MMA"
+ if self.model_args.arch_name == "deepseek_v3_2":
+ return self.convert_to_fp16mma_128cta(weights_list)
+ assert self.model_args.arch_name == "glm_5", "Only GLM-5 and DSV3.2 support FP16 MMA"
with torch.inference_mode():
mat, scales = weights_list
@@ -187,32 +205,44 @@ def convert_to_fp16mma(
)
scales = scales.to(torch.float32)
- sms = 128 # use 128 sms for glm_5
- pages = 4
- scales = scales.reshape(6144 // 128, 1, 2048 // 128)
+ dim = self.model_args.dim
+ block_size = self.model_args.block_size
+ sms = 128
+ vec_dim = mat.shape[-1]
+ dim_per_sm = dim // sms
+ tiles_per_stage = dim_per_sm // 16
+ stages = vec_dim // 512
+ dim_scale_dim = dim // block_size
+ vec_scale_dim = vec_dim // block_size
+ scale_per_stage = vec_scale_dim // stages
+
+ scales = scales.reshape(dim_scale_dim, 1, vec_scale_dim)
scales = scales.repeat(1, 8, 1)
- scales = scales.reshape(128, 3, 4, 4).transpose(1, 2)
- # to 128, 4, 12x4
- scales = scales.reshape(128, 4, 12).view(torch.float8_e4m3fn)
+ scales = (
+ scales.reshape(sms, tiles_per_stage, stages, scale_per_stage)
+ .transpose(1, 2)
+ .reshape(sms, stages, tiles_per_stage * scale_per_stage)
+ .view(torch.float8_e4m3fn)
+ )
mat = (
- mat.reshape(128, 48, 2048)
- .reshape(128, 3, 16, 4, 512)
+ mat.reshape(sms, dim_per_sm, vec_dim)
+ .reshape(sms, tiles_per_stage, 16, stages, 512)
.transpose(2, 3)
- .reshape(128, 3, 4, 16, 32, 16)
+ .reshape(sms, tiles_per_stage, stages, 16, 32, 16)
.transpose(3, 4)
- .reshape(128, 3, 4, 32, 16, 16)
+ .reshape(sms, tiles_per_stage, stages, 32, 16, 16)
)
mat = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat)
- mat = mat.transpose(1, 2).reshape(128, 4, -1)
+ mat = mat.transpose(1, 2).reshape(sms, stages, -1)
scales_padding = torch.zeros(
sms,
- pages,
+ stages,
128 - scales.shape[-1],
dtype=torch.float8_e4m3fn,
device=mat.device,
- ) # append 128-byte aligned scale: (128, 4, 24704) for glm_5
+ )
mat_full = torch.cat([mat, scales, scales_padding], dim=-1).contiguous()
dummy_scales = torch.zeros(1, dtype=torch.float32, device=mat.device)
return mat_full, dummy_scales
@@ -221,6 +251,15 @@ def convert_to_fp16mma(
class UnProjOAllReduce(TileRTModule):
"""UnProjOAllReduce module"""
+ _SUPPORTED_ALGORITHMS = {
+ "deepseek_v3_2": [
+ UnProjOAllReduceAlgorithm.FP16MMA,
+ ],
+ "glm_5": [
+ UnProjOAllReduceAlgorithm.FP16MMA,
+ ],
+ }
+
def __init__(
self,
model_args: ModelArgs,
@@ -228,7 +267,7 @@ def __init__(
device_id: int = 0,
ref_weights_alias: UnProjOAllReduceRefWeightsAlias | None = None,
tilert_weights_alias: UnProjOAllReduceTilertWeightsAlias | None = None,
- algorithm: UnProjOAllReduceAlgorithm = UnProjOAllReduceAlgorithm.FP8MMA,
+ algorithm: UnProjOAllReduceAlgorithm = UnProjOAllReduceAlgorithm.FP16MMA,
):
super().__init__(
self.__class__.__name__,
@@ -253,17 +292,22 @@ def __init__(
self.n_heads = self.model_args.n_heads
self.head_dim = self.model_args.v_head_dim
+ if self.n_heads % self.num_devices == 0:
+ self.num_local_heads = self.n_heads // self.num_devices
+ else:
+ n_local = math.ceil(self.n_heads / self.num_devices)
+ if n_local % 2 != 0:
+ n_local += 1
+ self.num_local_heads = n_local
+
self.block_size = self.model_args.block_size
self.algorithm: UnProjOAllReduceAlgorithm = algorithm
- # reference weights
self.ref_unproj_o: torch.Tensor | None = None
- # tilert weights
self.tilert_weights: torch.Tensor | None = None
self.tilert_scales: torch.Tensor | None = None
- # tilert vars
self.hidden_out: torch.Tensor | None = None
self.profile_logs: torch.Tensor | None = None
@@ -290,10 +334,55 @@ def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, tor
"""
unproj_o_weight = weights_map[self.ref_weights_alias.o_proj_weight]
unproj_o_scale = weights_map[self.ref_weights_alias.o_proj_scale_inv]
- unproj_o_weight = unproj_o_weight.reshape(self.dim, self.num_devices, -1)
- unproj_o_weight = unproj_o_weight.transpose(0, 1)
- unproj_o_scale = unproj_o_scale.reshape(self.dim // self.block_size, self.num_devices, -1)
- unproj_o_scale = unproj_o_scale.transpose(0, 1)
+
+ if self.n_heads % self.num_devices == 0:
+ unproj_o_weight = unproj_o_weight.reshape(self.dim, self.num_devices, -1)
+ unproj_o_weight = unproj_o_weight.transpose(0, 1)
+ unproj_o_scale = unproj_o_scale.reshape(
+ self.dim // self.block_size, self.num_devices, -1
+ )
+ unproj_o_scale = unproj_o_scale.transpose(0, 1)
+ else:
+ cols_per_head = self.head_dim
+ cols_per_dev = self.num_local_heads * cols_per_head
+ W = unproj_o_weight.view(self.dim, self.n_heads, cols_per_head)
+
+ scale_cols_per_head = cols_per_head // self.block_size
+ scale_cols_per_dev = self.num_local_heads * scale_cols_per_head
+ S = unproj_o_scale.view(self.dim // self.block_size, self.n_heads, scale_cols_per_head)
+
+ W_devs = []
+ S_devs = []
+ for dev in range(self.num_devices):
+ start = dev * self.num_local_heads
+ end = min(self.n_heads, start + self.num_local_heads)
+ real = max(0, end - start)
+
+ dev_W = torch.zeros(
+ self.dim,
+ self.num_local_heads,
+ cols_per_head,
+ dtype=W.dtype,
+ device=W.device,
+ )
+ if real > 0:
+ dev_W[:, :real] = W[:, start:end]
+ W_devs.append(dev_W.reshape(self.dim, cols_per_dev))
+
+ dev_S = torch.zeros(
+ self.dim // self.block_size,
+ self.num_local_heads,
+ scale_cols_per_head,
+ dtype=S.dtype,
+ device=S.device,
+ )
+ if real > 0:
+ dev_S[:, :real] = S[:, start:end]
+ S_devs.append(dev_S.reshape(self.dim // self.block_size, scale_cols_per_dev))
+
+ unproj_o_weight = torch.stack(W_devs, dim=0)
+ unproj_o_scale = torch.stack(S_devs, dim=0)
+
return {
self.tilert_weights_alias.unproj_weights: unproj_o_weight.contiguous(),
self.tilert_weights_alias.unproj_scales: unproj_o_scale.contiguous(),
@@ -413,12 +502,9 @@ def tilert_forward(
flag,
self.hidden_out,
self.profile_logs,
- self.algorithm.value,
+ model_arch=self.model_args.arch_name,
+ compute_kernel_type=self.algorithm.value,
)
- if self.flag_enable_profiling_log:
- parse_profile_log_tensor(
- self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)]
- )
return self.hidden_out
def __call__(
diff --git a/tilert/models/glm_5/_dsa_v32/temp_var_indices.py b/tilert/models/glm_5/_dsa_v32/temp_var_indices.py
new file mode 100644
index 0000000..3a7af62
--- /dev/null
+++ b/tilert/models/glm_5/_dsa_v32/temp_var_indices.py
@@ -0,0 +1,118 @@
+"""Named indices for DSA temporary variables.
+
+Lets Python code reference temp_vars by name instead of magic numbers.
+
+Usage::
+
+ from tilert.models.glm_5._dsa_v32.temp_var_indices import Idx
+
+ token_out = intermediates[Idx.TOKEN_OUT] # equivalent to intermediates[25]
+"""
+
+from enum import IntEnum
+
+
+class DsaTempVarIdx(IntEnum):
+ """Index constants for DSA temp_vars."""
+
+ Q = 0
+ KV = 1
+ KI = 2
+ Q_NOPE_DOWN = 3
+ Q_PE = 4
+ IQ = 5
+ IQ_RT = 6
+ IDX_SCORES = 7
+ IDX_LOGITS = 8
+ IDX_SELECTS = 9
+ Q_NOPE = 10
+ O = 11 # noqa: E741
+ O_ACC = 12
+ O_LSE = 13
+ O_LSE_ACC = 14
+ PROJ_O = 15
+ UNPROJ_O = 16
+ SCORES = 17
+ X_MLP_IN = 18
+ UP_GATE = 19
+ SEL_PROBS = 20
+ SEL_INDICES = 21
+ EXP_OUT = 22
+ X_RMSNORM = 23
+ LOGITS_OUT = 24
+ TOKEN_OUT = 25
+ EMBEDDING_RMSNORM = 26
+ HIDDEN_RMSNORM = 27
+ EH_PROJ = 28
+ X_TENSOR = 29
+ ROPE_FREQS = 30
+ CUR_POS = 31
+ TOKEN_ID = 32
+ LAST_HIDDEN_STATES = 33
+ DRAFT_TOKENS = 34
+ PREDICTED_TOKENS = 35
+ PREDICTED_HIDDEN = 36
+ ACCEPTED_TOKENS = 37
+ NEXT_DRAFT_TOKENS = 38
+ X_QUANT = 39
+ X_SCALE = 40
+ MOE_UP_GATE = 41
+ IDX_SEL_WS = 42
+ MTP0_TOKEN_OUT = 43
+ MTP1_TOKEN_OUT = 44
+ MTP0_EXP_OUT = 45
+ SAMPLING_SEED = 46
+ SAMPLING_POSITIONS = 47
+ SAMPLING_CONFIG = 48
+ TOP_P_SCORES = 49
+ TOP_P_DEBUG = 50
+ LORA_SLOT_ID = 51
+ LORA_RANK = 52
+ TOP_N_LOG_PROBS = 53
+ TOP_N_INDICES = 54
+ LOGPROBS_FLAG = 55
+
+
+TEMP_VARS_SIZE = 56
+
+Idx = DsaTempVarIdx
+
+
+def validate_temp_vars_layout() -> None:
+ """Validate the temporary-variable index enum.
+
+ Checks:
+ 1. Enum member count equals TEMP_VARS_SIZE.
+ 2. Indices are contiguous 0..TEMP_VARS_SIZE-1 with no gaps or duplicates.
+ 3. (If the backend is loaded) the backend temp_vars_size matches TEMP_VARS_SIZE.
+
+ Raises:
+ RuntimeError: If any validation check fails.
+ """
+ members = list(DsaTempVarIdx)
+
+ if len(members) != TEMP_VARS_SIZE:
+ raise RuntimeError(
+ f"DsaTempVarIdx has {len(members)} members but TEMP_VARS_SIZE={TEMP_VARS_SIZE}"
+ )
+
+ indices = sorted(m.value for m in members)
+ expected = list(range(TEMP_VARS_SIZE))
+ if indices != expected:
+ missing = set(expected) - set(indices)
+ dupes = [i for i in indices if indices.count(i) > 1]
+ raise RuntimeError(
+ f"DsaTempVarIdx indices are not contiguous 0..{TEMP_VARS_SIZE - 1}. "
+ f"Missing: {missing}, Duplicates: {set(dupes)}"
+ )
+
+ try:
+ import torch
+
+ cpp_size = torch.ops.tilert.dsa_temp_vars_size()
+ if cpp_size != TEMP_VARS_SIZE:
+ raise RuntimeError(
+ f"TEMP_VARS_SIZE={TEMP_VARS_SIZE} != " f"backend temp_vars_size={cpp_size}"
+ )
+ except (AttributeError, RuntimeError):
+ pass
diff --git a/python/models/glm_5/generator.py b/tilert/models/glm_5/generator.py
similarity index 79%
rename from python/models/glm_5/generator.py
rename to tilert/models/glm_5/generator.py
index 8bc9757..b3e8ddd 100644
--- a/python/models/glm_5/generator.py
+++ b/tilert/models/glm_5/generator.py
@@ -7,10 +7,10 @@
from transformers import AutoTokenizer
from tilert import logger
-from tilert.models.deepseek_v3_2.generator import stats_time
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
-from tilert.models.deepseek_v3_2.modules.end2end import ShowHandsDSALayer
-from tilert.models.deepseek_v3_2.temp_var_indices import Idx
+from tilert.models.glm_5._dsa_v32.generator import stats_time
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.modules.end2end import ShowHandsDSALayer
+from tilert.models.glm_5._dsa_v32.temp_var_indices import Idx
from tilert.tilert_init import tilert_init
__all__ = [
@@ -64,10 +64,9 @@ def __init__(
chat_template = f.read()
self.tokenizer.chat_template = chat_template
self.eos_id = self.tokenizer.eos_token_id
- self.batch_size = 1 # fixed batch size to 1 for now
+ self.batch_size = 1
self.mtp_seq_len = 4
- # GLM5 uses multiple stop tokens
self.stop_tokens = ["<|user|>", "<|endoftext|>", "<|observation|>", "<|assistant|>"]
self.stop_token_ids: set[int] = set()
for token in self.stop_tokens:
@@ -75,13 +74,11 @@ def __init__(
if len(token_ids) == 1:
self.stop_token_ids.add(token_ids[0])
else:
- # Try to get from added_tokens_encoder
if (
hasattr(self.tokenizer, "added_tokens_encoder")
and token in self.tokenizer.added_tokens_encoder
):
self.stop_token_ids.add(self.tokenizer.added_tokens_encoder[token])
- # Always include eos_id
if self.eos_id is not None:
self.stop_token_ids.add(self.eos_id)
logger.info(f"Stop token IDs: {self.stop_token_ids}")
@@ -113,6 +110,37 @@ def from_pretrained(self) -> None:
"""Load the model weights from the given path."""
self.decode_layer.from_pretrained(self.model_weights_dir)
+ def extract_ffn_cache(self) -> tuple[dict[int, list], dict[int, set[str]]]:
+ """Extract MOE/MLP op objects and skip keys from current loaded weights.
+
+ Returns:
+ Tuple of (cached_ffn_ops_per_device, skip_keys_per_device).
+ """
+ from tilert.models.glm_5._dsa_v32.modules.end2end import (
+ _extract_ffn_ops,
+ _get_moe_weight_keys,
+ )
+
+ cached_ffn_ops: dict[int, list] = {}
+ skip_keys: dict[int, set[str]] = {}
+ for device_id in range(self.decode_layer.num_devices):
+ dsa = self.decode_layer._dsa_objects[device_id]
+ if dsa is None:
+ raise RuntimeError(f"Device {device_id} Dsa not available for cache extraction")
+ cached_ffn_ops[device_id] = _extract_ffn_ops(dsa)
+ skip_keys[device_id] = _get_moe_weight_keys(dsa)
+ return cached_ffn_ops, skip_keys
+
+ def from_pretrained_with_cache(
+ self,
+ cached_ffn_ops_per_device: dict[int, list],
+ skip_keys_per_device: dict[int, set[str]],
+ ) -> None:
+ """Load weights reusing cached MOE/MLP ops."""
+ self.decode_layer.from_pretrained_with_cache(
+ self.model_weights_dir, cached_ffn_ops_per_device, skip_keys_per_device
+ )
+
def update_sampling_params(
self,
temperature: float = 1.0,
@@ -120,11 +148,7 @@ def update_sampling_params(
top_k: int = 256,
use_topp: bool = True,
) -> None:
- """Update sampling parameters for the next generation.
-
- Updates both the Python attributes and the CUDA sampling_config tensor
- that the TileRT kernel reads during forward pass.
- """
+ """Update sampling parameters for the next generation."""
self.temperature = temperature
self.decode_layer.update_sampling_config(
temperature=temperature, top_p=top_p, top_k=top_k, use_topp=use_topp
@@ -137,7 +161,7 @@ def generate(
print_log: bool = True,
with_mtp: bool | None = None,
prompt_tokens: list[int] | None = None,
- ) -> tuple[str, list[float], list[int]]:
+ ) -> tuple[str, list[float], list[int], int]:
"""Main function to load the model and perform single sequence generation.
Args:
@@ -149,7 +173,7 @@ def generate(
and use these tokens directly (useful for exact-length benchmarking).
Returns:
- Tuple of (result_text, time_list, accepted_counts).
+ Tuple of (result_text, time_list, accepted_counts, prompt_len).
accepted_counts is empty for non-MTP mode.
"""
active_mtp = with_mtp if with_mtp is not None else self.with_mtp
@@ -158,10 +182,10 @@ def generate(
self.decode_layer.set_sampling_seed(self.sampling_seed, with_mtp=active_mtp)
if active_mtp:
return self._generate_with_mtp(prompt, print_log, prompt_tokens=prompt_tokens)
- result, time_list = self._generate_without_mtp(
+ result, time_list, prompt_len = self._generate_without_mtp(
prompt, print_log, with_mtp=active_mtp, prompt_tokens=prompt_tokens
)
- return result, time_list, [] # Empty accepted_counts for non-MTP
+ return result, time_list, [], prompt_len
def _generate_without_mtp(
self,
@@ -169,7 +193,7 @@ def _generate_without_mtp(
print_log: bool = True,
with_mtp: bool = False,
prompt_tokens: list[int] | None = None,
- ) -> tuple[str, list[float]]:
+ ) -> tuple[str, list[float], int]:
"""Standard generation without MTP."""
if prompt_tokens is None:
messages = [{"role": "user", "content": prompt}]
@@ -179,11 +203,7 @@ def _generate_without_mtp(
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
- # adapt to transformers 5.2.0
- if not isinstance(prompt_tokens, list) and prompt_tokens.get("input_ids") is not None:
- prompt_tokens = prompt_tokens["input_ids"]
- assert prompt_tokens is not None
max_seq_len = self.config.max_seq_len
prompt_len = len(prompt_tokens)
total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
@@ -211,14 +231,12 @@ def _generate_without_mtp(
time_list.append(end_time - start_time)
intermediates, *_ = multi_devices_results[0]
- next_token = intermediates[Idx.TOKEN_OUT][0][0] # only the first token
+ next_token = intermediates[Idx.TOKEN_OUT][0][0]
- # replace the next token with the prompt token if the prompt mask is True
next_token = torch.where(
prompt_mask[0, cur_pos_val], tokens[0, cur_pos_val], next_token
)
tokens[0, cur_pos_val] = next_token
- # Check if next_token is any of the stop tokens
is_stop_token = next_token.item() in self.stop_token_ids
finished |= torch.logical_and(
~prompt_mask[0, cur_pos_val],
@@ -242,13 +260,11 @@ def _generate_without_mtp(
stats_time(time_list, "==== Performance ====")
print("\n")
- # Reset sequence after generation, i.e. reset the cur_pos to 0 internally
self.decode_layer.reset_sequence()
completion_tokens = []
for _, toks in enumerate(tokens.tolist()):
toks = toks[prompt_len : prompt_len + self.max_new_tokens]
- # Find first stop token and truncate
stop_idx = len(toks)
for i, tok in enumerate(toks):
if tok in self.stop_token_ids:
@@ -259,14 +275,14 @@ def _generate_without_mtp(
decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
- return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list
+ return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list, prompt_len
def _generate_with_mtp(
self,
prompt: str,
print_log: bool = True,
prompt_tokens: list[int] | None = None,
- ) -> tuple[str, list[float], list[int]]:
+ ) -> tuple[str, list[float], list[int], int]:
"""Generation with MTP (Multi-Token Prediction) speculative decoding."""
if prompt_tokens is None:
prompt_tokens = self.tokenizer.apply_chat_template(
@@ -274,16 +290,11 @@ def _generate_with_mtp(
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
- # adapt to transformers 5.2.0
- if not isinstance(prompt_tokens, list) and prompt_tokens.get("input_ids") is not None:
- prompt_tokens = prompt_tokens["input_ids"]
- assert prompt_tokens is not None
max_seq_len = self.config.max_seq_len
prompt_len = len(prompt_tokens)
total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
- # Output tokens buffer
tokens = torch.full(
(self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device
)
@@ -293,17 +304,14 @@ def _generate_with_mtp(
prefill_time_list = []
decode_time_list = []
- decode_accepted_counts = [] # Only track decode phase for statistics
- cur_pos = 0 # Current position in the output sequence
+ decode_accepted_counts = []
+ cur_pos = 0
- # Prefill phase: process prompt tokens in non-overlapping chunks.
- # Each chunk fills unique KV cache positions for both main model and MTP[0].
while cur_pos < prompt_len - 1:
draft_end = min(cur_pos + self.mtp_seq_len, prompt_len)
draft_tokens = tokens[0, cur_pos:draft_end].clone()
actual_token_count = draft_tokens.shape[0]
- # Pad if needed (use last token for padding)
if actual_token_count < self.mtp_seq_len:
pad_token = draft_tokens[-1].item()
padding = torch.full(
@@ -316,18 +324,13 @@ def _generate_with_mtp(
draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
- # Provide the extra token for MTP[0]'s shifted input last position.
- # MTP[0] needs tokens[cur_pos+1 : cur_pos+mtp_seq_len+1], so the
- # extra token is at cur_pos + mtp_seq_len.
mtp_extra_pos = cur_pos + self.mtp_seq_len
if mtp_extra_pos < prompt_len:
mtp_extra_token = int(tokens[0, mtp_extra_pos].item())
else:
- # Beyond prompt — use last valid draft token as padding
mtp_extra_token = int(tokens[0, draft_end - 1].item())
self.decode_layer.set_prefill_mtp_extra_token(mtp_extra_token)
- # Tell GPU how many tokens are valid (for cur_pos advancement)
self.decode_layer.set_prefill_valid_tokens(actual_token_count)
start_time = time.time()
@@ -335,27 +338,16 @@ def _generate_with_mtp(
end_time = time.time()
prefill_time_list.append(end_time - start_time)
- # No overlap: advance by the full actual_token_count
cur_pos += actual_token_count
- # After no-overlap prefill, cur_pos may have overshot to prompt_len.
- # Reset to prompt_len - 1 for correct decode start (first decode
- # reprocesses the last prompt token position).
cur_pos = prompt_len - 1
self.set_cur_pos(prompt_len - 1)
- # Decode phase: speculative decoding
- # Set prefill_valid_tokens to 0 to switch to decode mode
self.decode_layer.set_prefill_valid_tokens(0)
finished = False
while cur_pos < total_len - 1 and not finished:
- # Get next_draft_tokens from previous iteration
- # (or use last prompt tokens for first decode)
if cur_pos == prompt_len - 1:
- # First decode iteration: use last prompt token repeated as placeholder drafts
- # We can't use [t6, t7, t8, t9] because that would apply wrong RoPE positions
- # (cur_pos=9 means positions 9,10,11,12, but t6 should be at position 6)
last_token = tokens[0, prompt_len - 1].item()
draft_tokens = torch.full(
(self.mtp_seq_len,),
@@ -365,7 +357,6 @@ def _generate_with_mtp(
)
draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
else:
- # Use next_draft_tokens from previous iteration
draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape(
1, self.mtp_seq_len
)
@@ -376,11 +367,9 @@ def _generate_with_mtp(
decode_time_list.append(end_time - start_time)
num_accepted = self.decode_layer.get_num_accepted(0)
- # Use predicted_tokens for output (not next_draft_tokens which is for next iteration)
predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten()
decode_accepted_counts.append(num_accepted)
- # Add accepted tokens to output
num_output_tokens = num_accepted
for i in range(num_output_tokens):
if cur_pos + 1 + i >= total_len:
@@ -388,12 +377,10 @@ def _generate_with_mtp(
new_token = int(predicted_tokens[i].item())
tokens[0, cur_pos + 1 + i] = new_token
- # Print generated token
if cur_pos + 1 + i >= prompt_len and print_log:
decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True)
print(decoded_text, end="", flush=True)
- # Check for any stop token
if new_token in self.stop_token_ids:
finished = True
break
@@ -414,7 +401,6 @@ def _generate_with_mtp(
f"min={min_accepted}, max={max_accepted}"
)
- # Calculate correct TPS accounting for MTP's multiple tokens per call
if decode_time_list:
total_decode_time = sum(decode_time_list)
effective_tps = total_tokens / total_decode_time if total_decode_time > 0 else 0
@@ -427,16 +413,12 @@ def _generate_with_mtp(
print("\n")
- # Reset sequence after generation
self.decode_layer.reset_sequence()
- # Extract completion tokens
completion_tokens = []
for _, toks in enumerate(tokens.tolist()):
toks = toks[prompt_len : prompt_len + self.max_new_tokens]
- # Remove -1 padding and tokens after any stop token
toks = [t for t in toks if t != -1]
- # Find first stop token and truncate
stop_idx = len(toks)
for i, tok in enumerate(toks):
if tok in self.stop_token_ids:
@@ -451,6 +433,7 @@ def _generate_with_mtp(
f"{decoded_tokens[0]}\n" if decoded_tokens else "",
decode_time_list,
decode_accepted_counts,
+ prompt_len,
)
def inject_cache(
@@ -490,7 +473,6 @@ def inject_cache(
logger.warning("inject_cache called with empty layer_caches")
return
- # Infer seqlen from first tensor if end_pos not specified
first_ki, _, _ = layer_caches[0]
seqlen = first_ki.size(0)
if end_pos is None:
@@ -509,13 +491,8 @@ def inject_cache(
logger.warning(f"Layer index {layer_id} is out of bounds, skipping.")
break
- # GLM-5 cache layout: 3 tensors per layer (ki, kv, pe)
- # Based on CacheVarsGlm5: k_cache, kv_cache, pe_cache
base_idx = layer_id * 3
- # Copy to device and inject into cache
- # Cache layout: [batch=1, max_seq_len, dim]
- # External data: [seqlen, dim]
ki_src = ki[:cache_len].to(f"cuda:{device_id}")
kv_src = kv[:cache_len].to(f"cuda:{device_id}")
pe_src = pe[:cache_len].to(f"cuda:{device_id}")
@@ -527,14 +504,11 @@ def inject_cache(
logger.info(f"Cache injection completed for {num_devices} devices")
def set_cur_pos(self, cur_pos: int) -> None:
- """Set the current position for RoPE in C++ backend.
-
- This should be called after inject_cache() to ensure the C++ global
- g_cur_pos matches the injected cache length. This is critical for
- correct RoPE position encoding during continued generation.
+ """Set the current position for RoPE.
- For MTP mode, sets the GPU tensor at intermediates[31] directly.
- For non-MTP mode, calls the C++ dsa_show_hands_set_cur_pos_glm5 API.
+ This should be called after inject_cache() to ensure the runtime position
+ matches the injected cache length, for correct RoPE position encoding
+ during continued generation.
Args:
cur_pos: The current sequence position (typically the length of prefilled tokens).
@@ -545,14 +519,12 @@ def set_cur_pos(self, cur_pos: int) -> None:
>>> # Now generate continues from the correct position
"""
if self.with_mtp:
- # MTP E2E uses cur_pos tensor in TempVars
num_devices = self.decode_layer.num_devices
for device_id in range(num_devices):
intermediates, _, _, _ = self.decode_layer._get_device_result(device_id)
cur_pos_tensor = intermediates[Idx.CUR_POS]
cur_pos_tensor.fill_(cur_pos)
else:
- # Non-MTP uses the C++ global g_cur_pos
torch.ops.tilert.dsa_show_hands_set_cur_pos_glm5(cur_pos)
logger.info(f"Set cur_pos to {cur_pos}")
@@ -560,8 +532,7 @@ def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None:
"""Inject the last hidden state for MTP mode.
For MTP (Multi-Token Prediction), the MTP preprocess layer needs the
- last hidden state from the main model's last token. This method injects
- the hidden state into intermediates[33] (last_hidden_states slot).
+ last hidden state from the main model's last token.
Args:
last_hidden_state: [hidden_size] or [1, hidden_size] BF16 tensor.
@@ -577,14 +548,12 @@ def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None:
logger.warning("inject_last_hidden_state called but with_mtp is False, skipping")
return
- # Normalize shape to [1, hidden_size]
if last_hidden_state.dim() == 1:
last_hidden_state = last_hidden_state.unsqueeze(0)
num_devices = self.decode_layer.num_devices
for device_id in range(num_devices):
intermediates, _, _, _ = self.decode_layer._get_device_result(device_id)
- # Shape: [batch=1, seq=4, hidden_size], we set seq[0] since it's the last token
lhs_tensor = intermediates[Idx.LAST_HIDDEN_STATES]
lhs_src = last_hidden_state.to(f"cuda:{device_id}")
lhs_tensor[0, 0, :].copy_(lhs_src.squeeze(0))
diff --git a/python/models/glm_5/model_args.py b/tilert/models/glm_5/model_args.py
similarity index 92%
rename from python/models/glm_5/model_args.py
rename to tilert/models/glm_5/model_args.py
index a64ed6f..74e830c 100644
--- a/python/models/glm_5/model_args.py
+++ b/tilert/models/glm_5/model_args.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Literal
-from tilert.models.deepseek_v3_2.model_args import ModelArgs
+from tilert.models.glm_5._dsa_v32.model_args import ModelArgs
__all__ = [
"ModelArgsGLM5",
@@ -52,7 +52,7 @@ class ModelArgsGLM5(ModelArgs):
arch_name = "glm_5"
- max_batch_size: int = 1 # NOTE: the current implementation only supports a batch size being 1
+ max_batch_size: int = 1
max_seq_len: int = 202752
dtype: Literal["bf16", "fp8"] = "fp8"
scale_fmt: str | None = None
@@ -65,23 +65,18 @@ class ModelArgsGLM5(ModelArgs):
n_dense_layers: int = 3
n_heads: int = 64
- # moe
n_routed_experts: int = 256
n_shared_experts: int = 1
n_activated_experts: int = 8
- # n_expert_groups: int = 8
- # n_limited_groups: int = 4
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 2.5
- # mla
q_lora_rank: int = 2048
kv_lora_rank: int = 512
qk_nope_head_dim: int = 192
qk_rope_head_dim: int = 64
v_head_dim: int = 256
- # yarn
original_seq_len: int | None = None
rope_theta: float = 1000000.0
rope_factor: float | None = None
@@ -89,12 +84,10 @@ class ModelArgsGLM5(ModelArgs):
beta_slow: int | None = None
mscale: float = 1.0
- # index
index_n_heads: int = 32
index_head_dim: int = 128
index_topk: int = 2048
- # quant
block_size: int = 128
eps: float = 1e-5
diff --git a/python/models/preprocess/weight_converter.py b/tilert/models/preprocess/weight_converter.py
similarity index 90%
rename from python/models/preprocess/weight_converter.py
rename to tilert/models/preprocess/weight_converter.py
index c0926aa..4973f69 100644
--- a/python/models/preprocess/weight_converter.py
+++ b/tilert/models/preprocess/weight_converter.py
@@ -2,7 +2,7 @@
import os
import pprint
from collections import OrderedDict
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast
import torch
from safetensors.torch import load_file, save_file
@@ -10,7 +10,7 @@
from tilert import logger
from tilert.models.deepseek_v3_2.model_args import ModelArgs
from tilert.models.deepseek_v3_2.model_args import ModelArgs as ModelArgsDsav32
-from tilert.models.deepseek_v3_2.modules.mla import Mla
+from tilert.models.deepseek_v3_2.modules.mla_v2 import PureMlaV2, SparseSelectMlaV2
from tilert.models.deepseek_v3_2.ops.down_allreduce import DownAllReduce
from tilert.models.deepseek_v3_2.ops.eh_proj_allreduce import EHProjAllReduce
from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import ExpertDownAllReduce
@@ -36,13 +36,13 @@ class WeightConverter:
def __init__(
self,
- model_args: ModelArgs,
+ model_args: ModelArgs | ModelArgsGLM5,
num_devices: int,
model_dir: str,
save_dir: str,
test_mode: bool = False,
) -> None:
- self.model_args = model_args
+ self.model_args = cast(ModelArgs, model_args)
self.num_devices = num_devices
self.model_dir = model_dir
self.save_dir = save_dir
@@ -62,8 +62,6 @@ def __init__(
self.index_file = "model.safetensors.index.json"
self.__check_dir()
- # specially treated the embedding, norm, and head weights
- # at the beginning and end of the model
self.emb_name = "model.embed_tokens.weight"
self.norm_name = "model.norm.weight"
self.head_name = "lm_head.weight"
@@ -156,17 +154,16 @@ def save_file_sharded(
max_size_bytes = self.parse_size(max_shard_size)
- tensor_nums = len(weights_dict) # placeholder for number sharded files
+ tensor_nums = len(weights_dict)
shards: list[ShardInfo] = []
current_shard: dict[str, torch.Tensor] = {}
current_size = 0
- shard_index = 1 # first shard is for embedding
+ shard_index = 1
def get_shard_filename(shard_index: int) -> str:
return f"{base_filename}-{shard_index:05d}-of-{tensor_nums:05d}.safetensors"
- # Save embedding tensor to separate file
save_file(self.emb_weights_dict, get_shard_filename(shard_index))
shards.append(
{
@@ -184,10 +181,8 @@ def get_shard_filename(shard_index: int) -> str:
{name: self.get_tensor_size_bytes(tensor) for name, tensor in dev_tensors.items()}
)
- # If adding this tensor would exceed max size, start a new shard
for tensor_name, tensor_size in tensor_sizes.items():
if current_size + tensor_size > max_size_bytes and current_shard:
- # Save current shard
shard_filename = get_shard_filename(shard_index)
logger.info(f"Saving shard {shard_index} to {shard_filename}")
save_file(current_shard, shard_filename)
@@ -199,11 +194,9 @@ def get_shard_filename(shard_index: int) -> str:
current_size = 0
shard_index += 1
- # Add tensor to current shard
current_shard[tensor_name] = dev_tensors[tensor_name]
current_size += tensor_size
- # Save the last shard for the current device
if current_shard:
shard_filename = get_shard_filename(shard_index)
logger.info(f"Saving shard {shard_index} to {shard_filename}")
@@ -213,7 +206,6 @@ def get_shard_filename(shard_index: int) -> str:
current_size = 0
shard_index += 1
- # Update shard filenames with correct total count
total_shards = len(shards)
for i, shard in enumerate(shards, 1):
old_filename = shard["filename"]
@@ -253,24 +245,31 @@ def transform_mla(
weights_hf: dict[str, torch.Tensor],
layer_id: int,
) -> dict[str, dict[str, torch.Tensor]]:
- mla_weights_map: dict[str, dict[str, torch.Tensor]] = {}
- for dev_id in range(self.num_devices):
- mla_weights_map.setdefault(f"dev_{dev_id}", {})
- mla = Mla(self.model_args, device_id=0, num_devices=self.num_devices)
- mla_raw_dict = {
- _k: weights_hf[f"model.layers.{layer_id}.{_k}"] for _k in mla.get_ref_weights_alias()
+ """Shard MLA weights across devices."""
+ mla_weights: dict[str, dict[str, torch.Tensor]] = {
+ f"dev_{dev_id}": {} for dev_id in range(self.num_devices)
}
- mla_sharded_dict = mla.device_sharding(mla_raw_dict)
- for dev_id in range(self.num_devices):
- for key, value in mla_sharded_dict.items():
- mla_weights_map[f"dev_{dev_id}"].update({key: value[dev_id].contiguous()})
- mla_weights = {}
- for dev_id in range(self.num_devices):
- mla_weights_dev = {}
- for key in mla_weights_map[f"dev_{dev_id}"].keys():
- mla_weights_dev.update({key: mla_weights_map[f"dev_{dev_id}"][key]})
- mla_weights.update({f"dev_{dev_id}": mla_weights_dev})
+ sparse_mla = SparseSelectMlaV2(self.model_args, device_id=0, num_devices=1)
+ sparse_raw_dict = {
+ _k: weights_hf[f"model.layers.{layer_id}.{_k}"]
+ for _k in sparse_mla.get_ref_weights_alias()
+ }
+ sparse_sharded = sparse_mla.device_sharding(sparse_raw_dict)
+ for key, value in sparse_sharded.items():
+ mla_weights["dev_0"][key] = value[0].contiguous()
+
+ num_pure_mla_devices = self.num_devices - 1
+ pure_mla = PureMlaV2(self.model_args, device_id=0, num_devices=num_pure_mla_devices)
+ pure_raw_dict = {
+ _k: weights_hf[f"model.layers.{layer_id}.{_k}"]
+ for _k in pure_mla.get_ref_weights_alias()
+ }
+ pure_sharded = pure_mla.device_sharding(pure_raw_dict)
+ for shard_idx in range(num_pure_mla_devices):
+ gpu_id = shard_idx + 1
+ for key, value in pure_sharded.items():
+ mla_weights[f"dev_{gpu_id}"][key] = value[shard_idx].contiguous()
return mla_weights
@@ -328,7 +327,6 @@ def transform_mlp(
layer_id: int,
) -> dict[str, dict[str, torch.Tensor]]:
"""Transform MLP weights."""
- print(RMSNormUpGateSiLU)
rmsnorm_up_gate_silu = RMSNormUpGateSiLU(
self.model_args, device_id=0, num_devices=self.num_devices
)
@@ -384,13 +382,7 @@ def transform_mtp(
weights_hf: dict[str, torch.Tensor],
layer_id: int,
) -> dict[str, dict[str, torch.Tensor]]:
- """Transform MTP weights.
-
- Transformations applied:
- - enorm.weight: Direct use (fp32)
- - hnorm.weight: Direct use (fp32)
- - eh_proj.weight: Split along dim 1, reshape [7168, 1792] -> [128, 7, 56, 256]
- """
+ """Transform MTP weights."""
enorm_weight_key = f"model.layers.{layer_id}.enorm.weight"
hnorm_weight_key = f"model.layers.{layer_id}.hnorm.weight"
enorm_weight = weights_hf[enorm_weight_key]
@@ -530,7 +522,7 @@ def _sort_key(filename: str) -> tuple[int, int]:
try:
return _get_layer_num(filename)
except ValueError:
- return (999999, 999999) # If layer number not found, put at the end
+ return (999999, 999999)
tilert_weights = sorted(
self.converted_weights_dict, key=lambda x: _sort_key(x), reverse=False
@@ -566,7 +558,6 @@ def append_mtp_weights_to_safetensors(
"""
torch.set_default_device(self.default_device)
- # Load existing index.json
existing_index_file = os.path.join(existing_save_dir, "model.safetensors.index.json")
if not os.path.exists(existing_index_file):
raise ValueError(f"Existing index file not found: {existing_index_file}")
@@ -577,11 +568,9 @@ def append_mtp_weights_to_safetensors(
existing_weight_map: dict[str, str] = existing_index["weight_map"]
existing_total_size: int = existing_index["metadata"]["total_size"]
- # Find the next shard number
existing_shards = set(existing_weight_map.values())
max_shard_num = 0
for shard_name in existing_shards:
- # Parse shard number from filename like "model.safetensors-00001-of-00010.safetensors"
parts = shard_name.replace(".safetensors", "").split("-")
if len(parts) >= 2:
try:
@@ -594,14 +583,11 @@ def append_mtp_weights_to_safetensors(
f"Found {len(existing_shards)} existing shards, max shard number: {max_shard_num}"
)
- # Convert MTP layer (layer 61) weights
- mtp_layer_idx = self.num_dense_layers + self.num_moe_layers # 61
+ mtp_layer_idx = self.num_dense_layers + self.num_moe_layers
logger.info(f"Converting MTP layer {mtp_layer_idx} weights...")
mla_weights, mlp_weights, mtp_weights = self.convert_a_layer(mtp_layer_idx)
- # Collect MTP layer weights for all devices
- # Clone tensors to avoid shared memory issues when saving to safetensors
mtp_layer_weights: dict[str, torch.Tensor] = {}
for weights_group in [mla_weights, mlp_weights, mtp_weights]:
for dev, params in weights_group.items():
@@ -611,21 +597,17 @@ def append_mtp_weights_to_safetensors(
logger.info(f"Collected {len(mtp_layer_weights)} MTP layer weight tensors")
- # Calculate size of new weights
new_weights_size = sum(self.get_tensor_size_bytes(t) for t in mtp_layer_weights.values())
- # Save MTP weights to new shard file(s)
- # Use a separate naming scheme to avoid modifying existing shards
max_size_bytes = self.parse_size(max_shard_size)
new_shards: list[ShardInfo] = []
current_shard: dict[str, torch.Tensor] = {}
current_size = 0
- mtp_shard_index = 1 # Start from 1 for MTP shards
+ mtp_shard_index = 1
for tensor_name, tensor in mtp_layer_weights.items():
tensor_size = self.get_tensor_size_bytes(tensor)
if current_size + tensor_size > max_size_bytes and current_shard:
- # Save current shard with MTP-specific naming
shard_filename = f"model_mtp_layer61-{mtp_shard_index:05d}.safetensors"
shard_path = os.path.join(existing_save_dir, shard_filename)
logger.info(f"Saving MTP shard to {shard_filename}")
@@ -640,7 +622,6 @@ def append_mtp_weights_to_safetensors(
current_shard[tensor_name] = tensor
current_size += tensor_size
- # Save the last shard
if current_shard:
shard_filename = f"model_mtp_layer61-{mtp_shard_index:05d}.safetensors"
shard_path = os.path.join(existing_save_dir, shard_filename)
@@ -648,12 +629,10 @@ def append_mtp_weights_to_safetensors(
save_file(current_shard, shard_path)
new_shards.append({"filename": shard_filename, "tensors": list(current_shard.keys())})
- # Update weight_map with new MTP weights (existing shards remain unchanged)
for shard in new_shards:
for tensor_name in shard["tensors"]:
existing_weight_map[tensor_name] = shard["filename"]
- # Update index.json
updated_index = {
"metadata": {"total_size": existing_total_size + new_weights_size},
"weight_map": existing_weight_map,
@@ -683,6 +662,7 @@ def append_mtp_weights_to_safetensors(
args = parser.parse_args()
model_type = args.model_type
+ model_args: ModelArgsDsav32 | ModelArgsGLM5
if model_type == "deepseek-v32":
model_args = ModelArgsDsav32()
elif model_type == "glm-5":
diff --git a/python/models/utils.py b/tilert/models/utils.py
similarity index 76%
rename from python/models/utils.py
rename to tilert/models/utils.py
index b5e81e5..8caaaee 100644
--- a/python/models/utils.py
+++ b/tilert/models/utils.py
@@ -10,13 +10,25 @@
import torch
+_FACTOR_OVERRIDE_UNSET = object()
+_THETA_OVERRIDE_UNSET = object()
-def precompute_freqs_cis(args) -> torch.Tensor: # type: ignore
+
+def precompute_freqs_cis( # type: ignore[no-untyped-def]
+ args,
+ *,
+ factor_override=_FACTOR_OVERRIDE_UNSET,
+ theta_override=_THETA_OVERRIDE_UNSET,
+) -> torch.Tensor:
"""
Pre-computes frequency-based complex exponential values for rotary positional embeddings.
Args:
args (ModelArgs): Model arguments containing positional embedding parameters.
+ factor_override: If unset, ``args.rope_factor`` is used. Pass a
+ numeric value to override the factor inline.
+ theta_override: If unset, ``args.rope_theta`` is used. Pass a numeric
+ value to override the rope base. ``None`` is rejected.
Returns:
torch.Tensor: Precomputed complex exponential values for positional embeddings.
@@ -25,8 +37,8 @@ def precompute_freqs_cis(args) -> torch.Tensor: # type: ignore
seqlen = args.max_seq_len
beta_fast = args.beta_fast
beta_slow = args.beta_slow
- base = args.rope_theta
- factor = args.rope_factor
+ base = args.rope_theta if theta_override is _THETA_OVERRIDE_UNSET else theta_override
+ factor = args.rope_factor if factor_override is _FACTOR_OVERRIDE_UNSET else factor_override
def find_correction_dim(num_rotations: float, dim: int, base: float, max_seq_len: int) -> float:
"""
@@ -106,27 +118,33 @@ def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Te
return torch.polar(torch.ones_like(freqs), freqs)
-def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
- """
-
- Applies rotary positional embeddings to the input tensor.
+def apply_rotary_emb(
+ x_in: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True
+) -> torch.Tensor:
+ """Applies rotary positional embeddings to the input tensor.
Args:
- x (torch.Tensor): Input tensor with positional embeddings to be applied.
- freqs_cis (torch.Tensor): Precomputed complex exponential values for
- positional embeddings.
+ x_in: Input tensor with positional embeddings to be applied.
+ freqs_cis: Precomputed complex exponential values for positional embeddings.
+ interleaved: If True (default), adjacent pairs (x0,x1),(x2,x3)... form
+ complex numbers. If False, half-half layout: (x0,x_{d/2}),(x1,x_{d/2+1})...
+ The DeepSeek-V3.2-Exp indexer uses interleaved=False.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x_in.dtype
- x_in = torch.view_as_complex(x_in.float().view(*x_in.shape[:-1], -1, 2))
+ shape = x_in.shape
+ if not interleaved:
+ x_in = x_in.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
+ x_in = torch.view_as_complex(x_in.float().view(*shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x_in.size(1), 1, x_in.size(-1))
y_out = torch.view_as_real(x_in * freqs_cis).flatten(3)
+ if not interleaved:
+ y_out = torch.cat([y_out[..., 0::2], y_out[..., 1::2]], dim=-1)
return y_out.to(dtype)
-# enumerate swizzle mode
class SwizzleMode(IntEnum):
"""Swizzle mode."""
@@ -136,7 +154,6 @@ class SwizzleMode(IntEnum):
SWIZZLE_128B = 128 // 16
-# See CUDA C++ programming Guide 10.29.3.2 for more details.
def gen_tensor_swizzle_map_1d(
rows: int, cols_in_16bytes: int, swizzle_mode: SwizzleMode = SwizzleMode.SWIZZLE_128B
) -> torch.Tensor:
diff --git a/python/tilert_init.py b/tilert/tilert_init.py
similarity index 100%
rename from python/tilert_init.py
rename to tilert/tilert_init.py
diff --git a/python/utils.py b/tilert/utils.py
similarity index 62%
rename from python/utils.py
rename to tilert/utils.py
index 47335d7..4cc3b47 100644
--- a/python/utils.py
+++ b/tilert/utils.py
@@ -5,28 +5,35 @@
import torch
__all__ = [
+ "alloc_misc_ws",
"cosine_similarity",
"relative_l2_error",
"get_profile_log_tensor",
"SLICES_FOR_TILERT_OP",
]
-
SLICES_FOR_TILERT_OP = 1
def get_profile_log_tensor(
- device_index: int = 0, device: torch.device | None = None, num_max_insts: int = 64
-) -> torch.Tensor:
+ device_index: int = 0,
+ device: torch.device | None = None,
+ num_max_insts: int = 64,
+) -> torch.Tensor | None:
"""Get a profile log tensor for the given device index.
+ Returns ``None`` when no CUDA GPUs are visible so the offline
+ weight-conversion path can run with ``CUDA_VISIBLE_DEVICES=""``.
+
Args:
device_index: The index of the device.
device: The device to use.
Returns:
- A profile log tensor.
+ A profile log tensor, or ``None`` if CUDA is unavailable.
"""
+ if not torch.cuda.is_available():
+ return None
if device is None:
device = torch.device("cuda", device_index)
@@ -38,6 +45,23 @@ def get_profile_log_tensor(
)
+def alloc_misc_ws(
+ num_max_insts: int = 64,
+ device_id: int = 0,
+) -> torch.Tensor:
+ """Allocate a misc workspace tensor.
+
+ Args:
+ num_max_insts: Maximum number of profiled instructions.
+ device_id: CUDA device index to allocate on.
+
+ Returns:
+ A zeroed int64 tensor of shape (total_rows, num_sm, 16) on the
+ requested CUDA device.
+ """
+ return torch.ops.tilert.alloc_misc_ws(num_max_insts, device_id)
+
+
def cosine_similarity(gt: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
"""Calculate the cosine similarity.