From 3772812eba2f1c332563f072d2ce03f3b3e1c611 Mon Sep 17 00:00:00 2001 From: xieck13 Date: Tue, 9 Sep 2025 22:35:43 +0800 Subject: [PATCH] Add Dockerfiles for SGLang and TE FP8 with necessary patches - Introduced Dockerfile for SGLang with CUDA 12.9.1 - Added Dockerfile for TE FP8, which builds on the SGLang image and installs additional packages and patches for Megatron-LM and SGLang. - Included patches for Megatron-LM and SGLang --- docker/Dockerfile.sglang.cu129 | 391 +++++++++++++++++++++++++ docker/Dockerfile.te_fp8.cu129 | 65 +++++ docker/patch/te_fp8/megatron.patch | 439 +++++++++++++++++++++++++++++ docker/patch/te_fp8/sglang.patch | 236 ++++++++++++++++ 4 files changed, 1131 insertions(+) create mode 100644 docker/Dockerfile.sglang.cu129 create mode 100644 docker/Dockerfile.te_fp8.cu129 create mode 100644 docker/patch/te_fp8/megatron.patch create mode 100644 docker/patch/te_fp8/sglang.patch diff --git a/docker/Dockerfile.sglang.cu129 b/docker/Dockerfile.sglang.cu129 new file mode 100644 index 0000000000..9436a398e2 --- /dev/null +++ b/docker/Dockerfile.sglang.cu129 @@ -0,0 +1,391 @@ +ARG CUDA_VERSION=12.9.1 +FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 as base + +ARG BUILD_TYPE=all +ARG BRANCH_TYPE=remote +ARG DEEPEP_COMMIT=b92d0d4860ce6866cd6d31bfbae937f9a7a3772b +ARG CMAKE_BUILD_PARALLEL_LEVEL=2 +ENV DEBIAN_FRONTEND=noninteractive \ + CUDA_HOME=/usr/local/cuda \ + GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ \ + NVSHMEM_DIR=/sgl-workspace/nvshmem/install +# Add GKE default lib and bin locations. 
+ENV PATH="${PATH}:/usr/local/nvidia/bin" \ + LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/nvidia/lib:/usr/local/nvidia/lib64" + +RUN apt update && apt install wget -y && apt install software-properties-common -y \ + && add-apt-repository ppa:deadsnakes/ppa -y \ + && apt install python3.12-full python3.12-dev python3.10-venv -y \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 2 \ + && update-alternatives --set python3 /usr/bin/python3.12 \ + && wget https://bootstrap.pypa.io/get-pip.py \ + && python3 get-pip.py + +# Set timezone and install all packages +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update && apt-get install -y --no-install-recommends \ + tzdata \ + software-properties-common netcat-openbsd kmod unzip openssh-server \ + curl wget lsof zsh ccache tmux htop git-lfs tree \ + build-essential cmake \ + libopenmpi-dev libnuma1 libnuma-dev \ + libibverbs-dev libibverbs1 libibumad3 \ + librdmacm1 libnl-3-200 libnl-route-3-200 libnl-route-3-dev libnl-3-dev \ + ibverbs-providers infiniband-diags perftest \ + libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \ + libboost-all-dev libssl-dev \ + libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \ + pybind11-dev \ + libhiredis-dev libcurl4-openssl-dev \ + libczmq4 libczmq-dev \ + libfabric-dev \ + patchelf \ + nvidia-dkms-550 \ + devscripts debhelper fakeroot dkms check libsubunit0 libsubunit-dev \ + && ln -sf /usr/bin/python3.12 /usr/bin/python \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +# GDRCopy installation +RUN mkdir -p /tmp/gdrcopy && cd /tmp \ + && git clone https://github.com/NVIDIA/gdrcopy.git -b v2.4.4 \ + && cd gdrcopy/packages \ + && CUDA=/usr/local/cuda ./build-deb-packages.sh \ + && dpkg -i 
gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \ + && cd / && rm -rf /tmp/gdrcopy + +# Fix DeepEP IBGDA symlink +RUN ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +FROM scratch AS local_src +COPY . /src + +FROM base AS build-image +# Install SGLang +WORKDIR /sgl-workspace +ARG BRANCH_TYPE +COPY --from=local_src /src /tmp/local_src +RUN if [ "$BRANCH_TYPE" = "local" ]; then \ + cp -r /tmp/local_src /sgl-workspace/sglang; \ + else \ + git clone --depth=1 https://github.com/sgl-project/sglang.git /sgl-workspace/sglang; \ + fi \ + && rm -rf /tmp/local_src +RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5lib six \ + && cd sglang \ + && case "$CUDA_VERSION" in \ + 12.6.1) CUINDEX=126 ;; \ + 12.8.1) CUINDEX=128 ;; \ + 12.9.1) CUINDEX=129 ;; \ + *) echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 ;; \ + esac \ + && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ + && python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps \ + && python3 -m flashinfer --download-cubin \ + && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ + python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.8/sgl_kernel-0.3.8+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ + fi \ + && if [ "$CUDA_VERSION" = "12.9.1" ]; then \ + python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.8/sgl_kernel-0.3.8+cu129-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ + fi + +# Download source files +RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz && \ + git clone https://github.com/deepseek-ai/DeepEP.git && \ + cd DeepEP && git checkout ${DEEPEP_COMMIT} && sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define 
NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \ + cd .. && \ + tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz && \ + mv nvshmem_src nvshmem && \ + rm -f /sgl-workspace/nvshmem_src_cuda12-all-all-3.3.9.tar.gz + +# Build and install NVSHMEM +RUN cd /sgl-workspace/nvshmem && \ + NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES="90" && \ + cmake --build build --target install -j${CMAKE_BUILD_PARALLEL_LEVEL} + +# Install DeepEP +RUN cd /sgl-workspace/DeepEP && \ + case "$CUDA_VERSION" in \ + 12.6.1) \ + CHOSEN_TORCH_CUDA_ARCH_LIST='9.0' \ + ;; \ + 12.8.1|12.9.1) \ + CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0' \ + ;; \ + *) \ + echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 \ + ;; \ + esac && \ + NVSHMEM_DIR=${NVSHMEM_DIR} TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install . 
+ +# Python tools +RUN python3 -m pip install --no-cache-dir \ + datamodel_code_generator \ + mooncake-transfer-engine==0.3.5 \ + pre-commit \ + pytest \ + black \ + isort \ + icdiff \ + uv \ + wheel \ + scikit-build-core \ + nixl \ + py-spy + +# Install development tools and utilities +RUN apt-get update && apt-get install -y \ + gdb \ + ninja-build \ + vim \ + tmux \ + htop \ + wget \ + curl \ + locales \ + lsof \ + git \ + git-lfs \ + zsh \ + tree \ + silversearcher-ag \ + cloc \ + unzip \ + pkg-config \ + libssl-dev \ + bear \ + ccache \ + less \ + && apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN apt update -y \ + && apt install -y --no-install-recommends gnupg \ + && echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64 /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \ + && apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub \ + && apt update -y \ + && apt install nsight-systems-cli -y + +# Set up locale +RUN locale-gen en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US:en +ENV LC_ALL en_US.UTF-8 + +# Install minimal Python packages +RUN python3 -m pip install --no-cache-dir --break-system-packages \ + pytest \ + black \ + isort \ + icdiff \ + scikit_build_core \ + uv \ + pre-commit \ + pandas \ + matplotlib \ + tabulate + +# Install diff-so-fancy +RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \ + && chmod +x /usr/local/bin/diff-so-fancy + +# Install clang-format +RUN curl -LSso /usr/local/bin/clang-format https://github.com/muttleyxd/clang-tools-static-binaries/releases/download/master-32d3ac78/clang-format-16_linux-amd64 \ + && chmod +x /usr/local/bin/clang-format + +# Install clangd +RUN curl -L 
https://github.com/clangd/clangd/releases/download/18.1.3/clangd-linux-18.1.3.zip -o clangd.zip \ + && unzip clangd.zip \ + && cp -r clangd_18.1.3/bin/* /usr/local/bin/ \ + && cp -r clangd_18.1.3/lib/* /usr/local/lib/ \ + && rm -rf clangd_18.1.3 clangd.zip + +# Install CMake +RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1-linux-x86_64.tar.gz \ + && tar -xzf cmake-3.31.1-linux-x86_64.tar.gz \ + && cp -r cmake-3.31.1-linux-x86_64/bin/* /usr/local/bin/ \ + && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ + && rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz + +# Install Rust toolchain for sgl-router +ENV PATH="/root/.cargo/bin:${PATH}" +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \ + && rustc --version && cargo --version + +# Build and install sgl-router +RUN python3 -m pip install --no-cache-dir setuptools-rust \ + && cd /sgl-workspace/sglang/sgl-router \ + && cargo build --release \ + && python3 -m pip install --no-cache-dir . 
\ + && rm -rf /root/.cache + + +# Add yank script +COPY --chown=root:root <<-"EOF" /usr/local/bin/yank +#!/bin/bash +put() { + esc=$1 + test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\" + printf "$esc" +} +put "\033]52;c;!\a" +buf=$( cat "$@" ) +len=$( printf %s "$buf" | wc -c ) max=74994 +test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2 +put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a" +test -n "$TMUX" && tmux set-buffer "$buf" ||: +EOF + +RUN chmod +x /usr/local/bin/yank + +# Install oh-my-zsh and plugins +RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended \ + && git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \ + && git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting + +# Configure Vim +COPY --chown=root:root <<-"EOF" /root/.vimrc +function! Yank(text) abort + let escape = system('yank', a:text) + if v:shell_error + echoerr escape + else + call writefile([escape], '/dev/tty', 'b') + endif +endfunction + +noremap y y:call Yank(@0) + +" automatically run yank(1) whenever yanking in Vim +function! 
CopyYank() abort + call Yank(join(v:event.regcontents, "\n")) +endfunction + +autocmd TextYankPost * call CopyYank() + +" Basic settings +set number +syntax on +set mouse=a +filetype indent on + +" Indentation +set autoindent nosmartindent +set smarttab +set expandtab +set shiftwidth=4 +set softtabstop=4 + +" Visual guides +set colorcolumn=120 +highlight ColorColumn ctermbg=5 + +" Status line +set laststatus=2 +set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P + +" Backspace behavior +set backspace=2 + +" Encoding +set encoding=utf-8 +set fileencoding=utf-8 +EOF + +# Configure tmux +COPY --chown=root:root <<-"EOF" /root/.tmux.conf +# Pane border styling +set -g pane-border-style fg='#742727',bg=black +set -g pane-active-border-style fg=red,bg=black + +# Status bar styling +set -g status-style bg='#0C8A92',fg=black + +# Change prefix key to backtick +set-option -g prefix ` +unbind C-b +bind-key ` send-prefix + +# Split panes using - and = with current path +unbind '"' +bind - splitw -v -c '#{pane_current_path}' +unbind '%' +bind = splitw -h -c '#{pane_current_path}' + +# Vi mode settings +bind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}' +set-window-option -g mode-keys vi + +# Other settings +set-option -g escape-time 0 +set-option -g base-index 1 +set-window-option -g mouse on +set -g history-limit 100000 +EOF + +# Configure Git +RUN git config --global core.editor "vim" \ + && git config --global core.whitespace "fix,-indent-with-non-tab,trailing-space,cr-at-eol" \ + && git config --global core.pager "diff-so-fancy | less --tabs=4 -RFX" \ + && git config --global color.ui true \ + && git config --global color."diff-highlight".oldNormal "red bold" \ + && git config --global color."diff-highlight".oldHighlight "red bold 52" \ + && git config --global color."diff-highlight".newNormal "green bold" \ + && git config --global color."diff-highlight".newHighlight 
"green bold 22" \ + && git config --global color.diff.meta "11" \ + && git config --global color.diff.frag "magenta bold" \ + && git config --global color.diff.commit "yellow bold" \ + && git config --global color.diff.old "red bold" \ + && git config --global color.diff.new "green bold" \ + && git config --global color.diff.whitespace "red reverse" \ + && git config --global alias.lg "log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --" \ + && git config --global http.sslVerify false \ + && git config --global pull.rebase true + +# Configure zsh +COPY --chown=root:root <<-"EOF" /root/.zshrc +export ZSH="/root/.oh-my-zsh" + +# Theme +ZSH_THEME="robbyrussell" + +# Plugins +plugins=( + git + z + zsh-autosuggestions + zsh-syntax-highlighting +) + +source $ZSH/oh-my-zsh.sh + +# Aliases +alias ll='ls -alF' +alias la='ls -A' +alias l='ls -CF' +alias vi='vim' + +# Enhanced history +HISTSIZE=10000 +SAVEHIST=10000 +setopt HIST_IGNORE_ALL_DUPS +setopt HIST_FIND_NO_DUPS +setopt INC_APPEND_HISTORY +EOF + +RUN set -euxo ; \ + curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- --to /usr/local/bin + +# Set workspace directory +WORKDIR /sgl-workspace/sglang diff --git a/docker/Dockerfile.te_fp8.cu129 b/docker/Dockerfile.te_fp8.cu129 new file mode 100644 index 0000000000..582ff615de --- /dev/null +++ b/docker/Dockerfile.te_fp8.cu129 @@ -0,0 +1,65 @@ +ARG SGLANG_VERSION=latest +# you can build this docker through docker/Dockerfile.sglang.cu129 +FROM infix/sglang:cu129-latest AS sglang + +# we need to write this again after from +ARG SGLANG_VERSION +ARG MEGATRON_COMMIT=main + +RUN apt update +RUN apt install -y nvtop + +# TODO: change to pip install sglang-router after it has a new release +RUN pip install sglang-router --force-reinstall +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git --no-cache-dir --force-reinstall +RUN pip install ray[default] +RUN 
pip install httpx[http2] wandb pylatexenc blobfile accelerate "mcp[cli]" + +# mbridge +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps + +RUN TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0;9.0a" pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 +# apex +RUN NVCC_APPEND_FLAGS="--threads 4" \ + pip -v install --disable-pip-version-check --no-cache-dir \ + --no-build-isolation \ + --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git +# transformer engine, built with --no-build-isolation so it compiles against the torch already installed in the image +RUN pip install pybind11 +RUN pip -v install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@3cd6870 +# flash attn +# the newest version megatron supports is v2.7.4.post1 +RUN MAX_JOBS=64 pip -v install flash-attn==2.7.4.post1 +RUN git clone https://github.com/Dao-AILab/flash-attention.git && cd flash-attention/ && git checkout 27f501d && cd hopper/ && python3 setup.py install +RUN python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` && \ + mkdir -p $python_path/flash_attn_3 && \ + wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py + + +WORKDIR /root/ +RUN git clone -b v0.0.1-fp8 https://github.com/InfiXAI/Megatron-LM.git --recursive && \ + cd Megatron-LM && \ + pip install -e . + +# sandwich norm for GLM models +COPY patch/te_fp8/megatron.patch /root/Megatron-LM/ +RUN cd Megatron-LM && \ + git checkout ${MEGATRON_COMMIT} && \ + git apply megatron.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." 
&& \ + exit 1; \ + fi && \ + rm megatron.patch + +# sglang patch +COPY patch/te_fp8/sglang.patch /sgl-workspace/sglang/ +RUN cd /sgl-workspace/sglang && \ + git apply sglang.patch && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm sglang.patch + +RUN rm /root/.tmux.conf diff --git a/docker/patch/te_fp8/megatron.patch b/docker/patch/te_fp8/megatron.patch new file mode 100644 index 0000000000..90308ed0de --- /dev/null +++ b/docker/patch/te_fp8/megatron.patch @@ -0,0 +1,439 @@ +diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py +index fe26e8b4..4451f277 100644 +--- a/megatron/core/distributed/__init__.py ++++ b/megatron/core/distributed/__init__.py +@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads + from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel + from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig ++ ++# Backward compatibility patch for FSDP module reorganization ++import sys ++import importlib.util ++ ++spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') ++if spec: ++ custom_fsdp = importlib.util.module_from_spec(spec) ++ spec.loader.exec_module(custom_fsdp) ++ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp ++ if hasattr(custom_fsdp, 'MegatronFSDP'): ++ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index 99c3edc0..26ea5cb4 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce 
the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index 002edb92..f7273488 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_te_op_fuser: Optional[bool] = False, + use_kitchen: bool = False, + use_te_activation_func: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + +@@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec( + ), + ), + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map={ + "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", + "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index df9adc3e..2f4f544a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -443,7 +443,7 @@ class GPTModel(LanguageModule): + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + +- if mtp_in_postprocess: ++ if mtp_in_postprocess and labels is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index 57332ac3..f3abd642 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -9,6 +9,7 @@ 
from typing import Callable, List, Optional
+
+ import numpy as np
+ import torch
++import torch.distributed as dist
+
+ from .utils import GlobalMemoryBuffer, is_torch_min_version
+
+@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
+         return None
+
+
++old_new_group = None
++
++
++def monkey_patch_torch_dist():
++    """Patch torch.distributed so non-gloo process groups can be destroyed and reloaded."""
++    global old_new_group
++    if old_new_group is not None:
++        return
++    print("Applying monkey patch to torch.distributed", flush=True)
++
++    old_new_group = dist.new_group
++
++    def new_group(*args, **kwargs):
++        group = old_new_group(*args, **kwargs)
++        # Groups explicitly created with the gloo backend are returned unwrapped.
++        if (
++            len(args) >= 3 and args[2] == "gloo" or
++            "backend" in kwargs and kwargs["backend"] == "gloo"
++        ):
++            return group
++
++        # Get ranks from arguments
++        if len(args) >= 1 and args[0] is not None:
++            ranks = args[0]
++        elif "ranks" in kwargs and kwargs["ranks"] is not None:
++            ranks = kwargs["ranks"]
++        else:
++            # If no ranks specified, use all ranks in world
++            ranks = list(range(dist.get_world_size()))
++
++        if len(ranks) == 1:
++            return group
++
++        group = ReloadableProcessGroup(group, ranks)
++        return group
++
++    dist.new_group = new_group
++
++    def get_new_function(func):
++        def new_function(*args, **kwargs):
++            args = (
++                arg.group if isinstance(arg, ReloadableProcessGroup) else arg
++                for arg in args
++            )
++            kwargs = {
++                k: (v.group if isinstance(v, ReloadableProcessGroup) else v)
++                for k, v in kwargs.items()
++            }
++            return func(*args, **kwargs)
++        return new_function
++
++    dist.get_rank = get_new_function(dist.get_rank)
++    dist.get_world_size = get_new_function(dist.get_world_size)
++    dist.get_backend = get_new_function(dist.get_backend)
++    dist.get_global_rank = get_new_function(dist.get_global_rank)
++    dist.get_group_rank = get_new_function(dist.get_group_rank)
++    dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks)
++
++    dist.all_reduce = get_new_function(dist.all_reduce)
++    dist.all_gather = get_new_function(dist.all_gather)
++    dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor)
++    dist.all_gather_object = get_new_function(dist.all_gather_object)
++    dist.all_to_all = get_new_function(dist.all_to_all)
++    dist.all_to_all_single = get_new_function(dist.all_to_all_single)
++    dist.broadcast = get_new_function(dist.broadcast)
++    dist.reduce = get_new_function(dist.reduce)
++    dist.reduce_scatter = get_new_function(dist.reduce_scatter)
++    dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor)
++    dist.scatter = get_new_function(dist.scatter)
++    dist.gather = get_new_function(dist.gather)
++    dist.barrier = get_new_function(dist.barrier)
++    dist.send = get_new_function(dist.send)
++    dist.recv = get_new_function(dist.recv)
++    dist._coalescing_manager = get_new_function(dist._coalescing_manager)
++
++    # p2p: keep the unpatched isend/irecv so P2POp can be handed the originals
++    old_isend = dist.isend
++    old_irecv = dist.irecv
++
++    dist.isend = get_new_function(dist.isend)
++    dist.irecv = get_new_function(dist.irecv)
++
++    def get_new_p2pop_function(func):
++        def new_function(*args, **kwargs):
++            def convert(arg):
++                if isinstance(arg, ReloadableProcessGroup):
++                    return arg.group
++                elif arg is dist.isend:  # identity test: arg may be a tensor
++                    arg = old_isend
++                elif arg is dist.irecv:
++                    arg = old_irecv
++                return arg
++
++            args = (convert(arg) for arg in args)
++            kwargs = {k: convert(v) for k, v in kwargs.items()}
++            return func(*args, **kwargs)
++        return new_function
++
++    dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__)
++    dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__)
++
++
++class ReloadableProcessGroup(torch.distributed.ProcessGroup):
++    """Wrapper around a process group that can be destroyed and re-created in place."""
++    GROUPS = []
++
++    def __init__(self, group, ranks):
++        super().__init__(
++            rank=dist.get_rank(group),
++            size=dist.get_world_size(group),
++        )
++        # Remember the ranks so reload_process_groups() can re-create the group.
++        self.group = group
++        self.group_info = {
++            "ranks": ranks,
++        }
++        ReloadableProcessGroup.GROUPS.append(self)
++
++    def __getattr__(self, name):
++        if name == "group":  # not set yet: avoid infinite recursion
++            raise AttributeError(name)
++        return getattr(self.group, name)
++
++    @staticmethod
++    def destroy_process_groups():
++        for reloadable_group in ReloadableProcessGroup.GROUPS:
++            if reloadable_group.group is None:
++                continue
++            # Destroy the wrapped group but keep the wrapper registered for reload.
++            dist.destroy_process_group(reloadable_group.group)
++            reloadable_group.group = None
++
++    @staticmethod
++    def reload_process_groups():
++        for reloadable_group in ReloadableProcessGroup.GROUPS:
++            if reloadable_group.group is not None:
++                continue
++            # NOTE: only the ranks are restored; the backend is always "nccl".
++            group = old_new_group(
++                ranks=reloadable_group.group_info["ranks"],
++                backend="nccl"
++            )
++            reloadable_group.group = group
++
++    def rank(self) -> int: return self.group.rank()
++    def size(self) -> int: return self.group.size()
++    def name(self) -> str: return self.group.name()
++
++    def shutdown(self) -> None:
++        if self.group is not None:
++            self.group.shutdown()
++
++    def abort(self) -> None:
++        if self.group is not None:
++            self.group.abort()
++
++    def _fwd(self, method, *args, **kwargs):
++        """Call ``method`` on the wrapped group; raise if the group was destroyed."""
++        inner = self.group
++        if inner is None:
++            raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload_process_groups() first.")
++        return getattr(inner, method)(*args, **kwargs)
++
++    def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw)
++    def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw)
++    def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw)
++    def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw)
++    def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw)
++    def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw)
++    def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw)
++    def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw)
++    def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw)
++    def gather(self, *a, **kw): return self._fwd("gather", *a, **kw)
++    def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw)
++    def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw)
++    def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw)
++    def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw)
++    def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw)
++    def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw)
++    def send(self, *a, **kw): return self._fwd("send", *a, **kw)
++    def recv(self, *a, **kw): return self._fwd("recv", *a, **kw)
++    def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw)
++
++    def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw)
++    def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw)
++    def _get_backend_name(self): return self._fwd("_get_backend_name")
++    def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw)
++    def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw)
++    @property
++    def bound_device_id(self): return self.group.bound_device_id
++    @bound_device_id.setter
++    def bound_device_id(self, dev): self.group.bound_device_id = dev
++
++
++def destroy_process_groups():
++    """Destroy all reloadable process groups."""
++    ReloadableProcessGroup.destroy_process_groups()
++
++
++def reload_process_groups():
++    """Reload all reloadable process groups."""
++    ReloadableProcessGroup.reload_process_groups()
++
++
++monkey_patch_torch_dist()
++
++
+ def create_group(
+     ranks=None,
+     timeout=None,
+diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
+index 
63ee9d1f..b90b744c 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index 6f557e1f..b295fd35 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig): + qk_layernorm: bool = False + """Whether to apply `normalization` type of normalization to the query and key embeddings.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 84f22bde..b4807d26 100644 +--- a/megatron/core/transformer/transformer_layer.py 
++++ b/megatron/core/transformer/transformer_layer.py +@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -399,6 +409,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -535,6 +552,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + 
attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -635,6 +657,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index 9381b0e4..8556856a 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1237,6 +1237,8 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1547,6 +1549,10 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. 
This option' + 'should not be used unless for backward compatibility' + + diff --git a/docker/patch/te_fp8/sglang.patch b/docker/patch/te_fp8/sglang.patch new file mode 100644 index 0000000000..20901415fd --- /dev/null +++ b/docker/patch/te_fp8/sglang.patch @@ -0,0 +1,236 @@ +diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py +index 1b96ae678..3b7023c55 100644 +--- a/python/sglang/srt/configs/model_config.py ++++ b/python/sglang/srt/configs/model_config.py +@@ -454,14 +454,14 @@ class ModelConfig: + ).lower() + + # Detect which checkpoint is it +- for _, method in QUANTIZATION_METHODS.items(): +- quantization_override = method.override_quantization_method( +- quant_cfg, self.quantization +- ) +- if quantization_override: +- quant_method = quantization_override +- self.quantization = quantization_override +- break ++ # for _, method in QUANTIZATION_METHODS.items(): ++ # quantization_override = method.override_quantization_method( ++ # quant_cfg, self.quantization ++ # ) ++ # if quantization_override: ++ # quant_method = quantization_override ++ # self.quantization = quantization_override ++ # break + + # Verify quantization configurations. 
+ if self.quantization is None: +diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py +index a6bcb0b5b..1ef6abe72 100644 +--- a/python/sglang/srt/entrypoints/http_server.py ++++ b/python/sglang/srt/entrypoints/http_server.py +@@ -264,6 +264,10 @@ async def validate_json_request(raw_request: Request): + + + @app.get("/health") ++async def health(request: Request) -> Response: ++ return Response(status_code=200) ++ ++ + @app.get("/health_generate") + async def health_generate(request: Request) -> Response: + """ +diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +index 372717bf9..40665cc90 100644 +--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py ++++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +@@ -190,6 +190,7 @@ class DeepEPBuffer: + f"Consider using --deepep-config to change the behavior." + ) + ++ num_qps_per_rank = 20 + cls._buffer = Buffer( + group, + num_nvl_bytes, +diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py +index 31a2c2e..ca1b149 100644 +--- a/python/sglang/srt/layers/quantization/fp8.py ++++ b/python/sglang/srt/layers/quantization/fp8.py +@@ -355,8 +355,8 @@ class Fp8LinearMethod(LinearMethodBase): + return + else: + weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data +- layer.weight = Parameter(weight, requires_grad=False) +- layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False) ++ # layer.weight = Parameter(weight, requires_grad=False) ++ # layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False) + else: + layer.weight = Parameter(layer.weight.data, requires_grad=False) + +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 8daa8af..078fec0 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ 
-1436,7 +1436,7 @@ class Scheduler( + + if memory_leak: + msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" +- raise ValueError(msg) ++ # raise ValueError(msg) + + if self.disaggregation_mode == DisaggregationMode.DECODE: + req_total_size = ( +@@ -1451,7 +1451,7 @@ class Scheduler( + f"available_size={len(self.req_to_token_pool.free_slots)}, " + f"total_size={self.req_to_token_pool.size}\n" + ) +- raise ValueError(msg) ++ # raise ValueError(msg) + + if ( + self.enable_metrics +@@ -1898,6 +1898,7 @@ class Scheduler( + speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, + require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), + disable_overlap_schedule=self.server_args.disable_overlap_schedule, ++ offload_tags=self.offload_tags, + ) + + @staticmethod +@@ -1912,6 +1913,7 @@ class Scheduler( + speculative_num_draft_tokens, + require_mlp_tp_gather: bool, + disable_overlap_schedule: bool, ++ offload_tags: set[str], + ): + # Check if other DP workers have running batches + if local_batch is None: +@@ -1942,7 +1944,7 @@ class Scheduler( + ) + + tbo_preparer = TboDPAttentionPreparer() +- if disable_overlap_schedule: ++ if len(offload_tags) == 0 and disable_overlap_schedule: + group = tp_group.device_group + device = tp_group.device + else: +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index 50ac39f88..33782a8cd 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -1045,10 +1045,15 @@ class TokenizerManager: + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() +- assert ( +- self.server_args.dp_size == 1 +- ), "dp_size must be 1 for init parameter update group" +- result = (await self.init_weights_update_group_communicator(obj))[0] ++ results = await self.init_weights_update_group_communicator(obj) ++ if self.server_args.dp_size == 1: ++ 
result = results[0] ++ return result.success, result.message ++ else: ++ all_success = all([r.success for r in results]) ++ all_message = [r.message for r in results] ++ all_message = " | ".join(all_message) ++ return all_success, all_message + return result.success, result.message + + async def update_weights_from_distributed( +@@ -1057,9 +1062,6 @@ class TokenizerManager: + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() +- assert ( +- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention +- ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed" + + if obj.abort_all_requests: + self.abort_request(abort_all=True) +@@ -1067,8 +1069,15 @@ class TokenizerManager: + # This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: +- result = (await self.update_weights_from_distributed_communicator(obj))[0] +- return result.success, result.message ++ results = await self.update_weights_from_distributed_communicator(obj) ++ if self.server_args.dp_size == 1: ++ result = results[0] ++ return result.success, result.message ++ else: ++ all_success = all([r.success for r in results]) ++ all_message = [r.message for r in results] ++ all_message = " | ".join(all_message) ++ return all_success, all_message + + async def update_weights_from_tensor( + self, +diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py +index 303919505..df2407eb6 100644 +--- a/python/sglang/srt/model_executor/cuda_graph_runner.py ++++ b/python/sglang/srt/model_executor/cuda_graph_runner.py +@@ -759,6 +759,22 @@ class CudaGraphRunner: + ) + if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: + forward_batch.spec_info.custom_mask = self.custom_mask ++ ++ if forward_batch.forward_mode.is_target_verify() and bs - raw_bs > 0: ++ # pad the spec_info 
custom mask ++ spec_info = forward_batch.spec_info ++ pad_len = ( ++ (bs - raw_bs) ++ * spec_info.draft_token_num ++ * (spec_info.draft_token_num + self.seq_len_fill_value) ++ ) ++ pad_mask = torch.full( ++ (pad_len,), ++ True, ++ device=spec_info.custom_mask.device, ++ ) ++ spec_info.custom_mask = torch.cat([spec_info.custom_mask, pad_mask], dim=0) ++ + # Attention backend + self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( + bs, +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index ee83c2d9c..5145d7f78 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -22,6 +22,7 @@ import os + import time + from dataclasses import dataclass + from typing import List, Optional, Tuple, Union ++from contextlib import nullcontext + + import torch + import torch.distributed as dist +@@ -667,7 +668,7 @@ class ModelRunner: + monkey_patch_vllm_parallel_state() + monkey_patch_isinstance_for_vllm_base_layer() + +- with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS): ++ with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS) if not self.is_draft_worker else nullcontext(): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, +diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py +index 67ef6ca79..b6ee944de 100644 +--- a/python/sglang/srt/models/glm4_moe.py ++++ b/python/sglang/srt/models/glm4_moe.py +@@ -695,7 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer): + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, +- allow_reduce_scatter=True, ++ allow_reduce_scatter=False, + ) + + def forward( +@@ -1108,5 +1108,4 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): + ) + weight_loader(param, loaded_weight) + +- + EntryClass = 
[Glm4MoeForCausalLM] +