本文档按特性模块组织,每个模块包含概述 + 接口详细说明。
HyperParallel 提供 FSDP/HSDP 数据并行能力,支持参数/梯度/优化器状态的分布式切分,显著降低单卡内存占用。
对模型应用 FSDP/HSDP 参数切分。
fully_shard(
module: Module,
mesh: DeviceMesh,
*,
reshard_after_forward: bool = True,
mixed_precision: Optional[MixedPrecisionPolicy] = None,
offload_policy: Optional[OffloadPolicy] = None,
comm_fusion: bool = True,
comm_fusion_zero_copy: Optional[bool] = None,
) -> Module参数:
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
module |
Module |
— | 要切分的模块 |
mesh |
DeviceMesh |
— | DP 维度的 DeviceMesh |
reshard_after_forward |
bool |
True |
正向后是否 reshard(节省内存) |
mixed_precision |
MixedPrecisionPolicy |
None |
混合精度策略 |
offload_policy |
OffloadPolicy |
None |
Offload 策略 |
comm_fusion |
bool |
True |
通信融合 |
comm_fusion_zero_copy |
bool |
None |
通信融合零拷贝(PyTorch 默认 True,MindSpore 默认 False) |
返回值: 切分后的模块(原地修改,返回同一对象)。
HSDP 模块封装类。
class HSDPModuleHSDPModule 封装了 HSDP 参数切分的核心逻辑,包括参数 unshard/reshard、梯度同步等。
HSDP 梯度同步流管理。
hsdp_sync_stream()返回 HSDP 使用的梯度同步流,确保梯度 all-reduce 与计算流正确同步。
DTensor 是 HyperParallel 的核心分布式张量抽象,封装 local shard + DeviceMesh + Placements。
分布式张量类。
class DTensor核心方法:
| 方法 | 说明 |
|---|---|
from_local(local_tensor, layout) |
从本地张量创建 DTensor |
to_local() |
获取本地 shard 张量 |
redistribute(device_mesh, placements) |
重分布到目标 layout |
full_tensor() |
获取完整(聚合后的)张量 |
is_partial() |
检查是否为 partial 状态 |
reduce_partial() |
对 partial DTensor 执行 reduce |
张量到 mesh 的排布映射。
Layout(mesh_shape, mesh_dim_names)参数:
| 参数 | 类型 | 说明 |
|---|---|---|
mesh_shape |
tuple |
mesh 形状,如 (4, 8) |
mesh_dim_names |
tuple |
mesh 维度名,如 ("dp", "tp") |
使用方式:
layout = Layout((dp, tp), ("dp", "tp"))
x_layout = layout("dp", "tp") # 指定各维度的排布
w_layout = layout("tp", "None") # "None" 表示不切分
out_layout = layout() # 空 layout 表示 Replicate多维设备拓扑管理。
class DeviceMesh核心方法:
| 方法 | 说明 |
|---|---|
__getitem__(dim_name) |
切片子 mesh,如 mesh["dp"] |
get_device_num_along_axis(dim_name) |
获取指定维度设备数 |
创建 DeviceMesh。
init_device_mesh(
device_type: str,
mesh_shape: tuple,
mesh_dim_names: tuple,
) -> DeviceMesh参数:
| 参数 | 类型 | 说明 |
|---|---|---|
device_type |
str |
设备类型,如 "npu" 或 "cuda" |
mesh_shape |
tuple |
mesh 形状 |
mesh_dim_names |
tuple |
mesh 维度名 |
获取当前活跃的 DeviceMesh。
get_current_mesh() -> DeviceMesh将模块的参数按指定 layout 分布。
distribute_module(module, device_mesh, ...)分片参数初始化。
init_parameters(model) -> Module对模型参数按 DTensor layout 进行分片初始化,避免先初始化完整参数再切分导致的内存峰值。
空权重初始化上下文。
init_empty_weights()延迟初始化上下文,创建参数时不分配内存,后续通过 init_parameters 分片初始化。
在设备上初始化参数。
init_on_device()分布式随机数种子控制(PyTorch parity)。
manual_seed(seed: int)设置分布式随机数种子,确保各 rank 的随机数生成一致。
声明式并行策略接口。
shard_module(
module: Module,
sharding_plan: dict,
) -> Module参数:
| 参数 | 类型 | 说明 |
|---|---|---|
module |
Module |
要切分的模块 |
sharding_plan |
dict |
切分配置,包含 forward(input/output)和 parameter |
sharding_plan 格式:
sharding_plan = {
"forward": {
"input": (x_layout,),
"output": (out_layout,),
},
"parameter": {
"weight": w_layout,
},
}自定义并行接入 DTensor 并行流程。
custom_shard(...)自定义分布式 autograd 函数基类。
class DFunction(platform.Function):
_op_name: str = None
@staticmethod
def forward(ctx, *args, **kwargs) -> Tensor: ...
@staticmethod
def backward(ctx, *grad_outputs) -> ...: ...
@classmethod
def apply(cls, *args, **kwargs) -> Tensor | DTensor: ...详细文档: 参见 DFunction 文档。
并行化的值与梯度计算。
parallelize_value_and_grad(...)梯度 hook 中绕过 DTensor dispatch 的标记。
class SkipDTensorDispatch在 FSDP/HSDP 的梯度 hook 中使用,直接操作 local tensor 而不经过 DTensor dispatch 系统。
列切分并行策略。
class ColwiseParallel(ParallelStyle):
input_layouts: Placement = Replicate()
output_layouts: Placement = Shard(-1)
use_local_output: bool = True行切分并行策略。
class RowwiseParallel(ParallelStyle):
input_layouts: Placement = Shard(-1)
output_layouts: Placement = Replicate()
use_local_output: bool = True序列并行策略。
class SequenceParallel(ParallelStyle)模块输入准备钩子。
class PrepareModuleInput(ParallelStyle)模块输出准备钩子。
class PrepareModuleOutput(ParallelStyle)模块输入/输出准备钩子。
class PrepareModuleInputOutput(ParallelStyle)并行策略基类。
class ParallelStyle所有并行策略的抽象基类,定义 apply(module, device_mesh) 接口。
声明式 TP 应用接口。
parallelize_module(
module: Module,
device_mesh: DeviceMesh,
parallelize_plan: dict | ParallelStyle,
*,
src_data_rank: int = 0,
) -> Module参数:
| 参数 | 类型 | 说明 |
|---|---|---|
module |
Module |
要并行化的模块 |
device_mesh |
DeviceMesh |
1-D TP DeviceMesh |
parallelize_plan |
dict 或 ParallelStyle |
并行策略配置,支持 fnmatch glob pattern |
src_data_rank |
int |
数据源 rank(默认 0) |
流水线 stage 封装。
class PipelineStage:
def __init__(
self,
module: Module,
stage_index: int,
stage_num: int,
)参数:
| 参数 | 类型 | 说明 |
|---|---|---|
module |
Module |
stage 对应的子模块 |
stage_index |
int |
当前 stage 的索引 |
stage_num |
int |
总 stage 数 |
1F1B 调度策略。
class Schedule1F1B:
def __init__(
self,
stage: PipelineStage,
micro_batch_num: int,
)GPipe 调度策略。
class ScheduleGPipe:
def __init__(
self,
stage: PipelineStage,
micro_batch_num: int,
)VPP(交错 1F1B)调度策略。
class ScheduleInterleaved1F1B:
def __init__(
self,
stages: list,
micro_batch_num: int,
*,
overlap_p2p: bool = False,
overlap_b_f: bool = False,
)参数:
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
stages |
list |
— | PipelineStage 列表 |
micro_batch_num |
int |
— | micro-batch 数 |
overlap_p2p |
bool |
False |
是否启用 P2P prefetch |
overlap_b_f |
bool |
False |
是否启用 B/F 通算掩盖 |
调度单元抽象。
class MetaStep描述调度单元,包含 step_type 和可选的 sub_steps。
调度单元类型枚举。
class MetaStepType(enum.Enum):
FORWARD = "F"
BACKWARD = "B"
OVERLAP_F_B = "OFB"
OVERLAP_B_F = "OBF"
FWD_RECV = "FR"
BWD_RECV = "BR"
FWD_SEND = "FS"
BWD_SEND = "BS"批量维度规格。
class BatchDimSpec双线程通算掩盖协调器。
class CommComputeOverlap:
def wrap_dispatch(self, dispatch_fn, ...) -> Callable
def wrap_combine(self, combine_fn, is_last_layer: bool = False) -> Callable
def run(self, fwd_fn: Callable, bwd_fn: Callable) -> NoneCOMM-first rendezvous 原语。
class HookCoordinator:
def enable(self) -> None
def disable(self) -> None
def rendezvous(self, hook_name: str, role: HookRole) -> None
class HookRole(enum.Enum):
COMM = "COMM"
COMPUTE = "COMPUTE"基础上下文并行。
class ContextParallel(ParallelStyle)异步上下文并行。
class AsyncContextParallel(ParallelStyle)| 类 | 说明 |
|---|---|
DSAIndexerContextParallel |
DSA 索引器 CP |
AsyncDSAIndexerContextParallel |
异步 DSA 索引器 CP |
DSAIndexerLossContextParallel |
DSA 索引器 + Loss CP |
AsyncDSAIndexerLossContextParallel |
异步 DSA 索引器 + Loss CP |
DSASparseAttentionContextParallel |
DSA Sparse Attention CP |
AsyncDSASparseAttentionContextParallel |
异步 DSA Sparse Attention CP |
所有 DSA 类均继承 ParallelStyle,支持 apply(module, device_mesh) 接口。
标准 all-to-all EP。
class ExpertParallel(BaseExpertParallel)每个 rank 持有 num_experts / ep_degree 个本地专家,通过 differentiable all-to-all 实现 token dispatch/combine。
Mesh 需求: 1-D mesh,如 mesh_dim_names=("ep",)。
TP-only 权重切分(EP degree = 1 时使用)。
class TensorParallel(BaseExpertParallel)EP+TP 二维并行。
class ExpertTensorParallel(BaseExpertParallel)Mesh 需求: 2-D mesh,如 mesh_dim_names=("ep", "tp")。
初始化进程组。
init_process_group(backend: str, ...)销毁进程组。
destroy_process_group()获取进程组内所有 rank。
get_process_group_ranks(group) -> list获取当前进程组的 backend 类型。
get_backend() -> str从父进程组中分裂出子进程组。
split_group(group, split_size, ...)获取进程组内的 local rank。
get_group_local_rank(group) -> int标记已创建的进程组。
mark_created_groups(groups)标准 AdamW 优化器。
class AdamW(torch.optim.Optimizer):
def __init__(
self,
params: list,
lr: float = 1e-3,
weight_decay: float = 0.01,
betas: tuple = (0.9, 0.999),
eps: float = 1e-8,
...
)Muon(momentum-based)优化器。
class Muon(torch.optim.Optimizer):
def __init__(
self,
params: list,
lr: float = 0.02,
momentum: float = 0.95,
ns_steps: int = 5,
nesterov: bool = True,
weight_decay: float = 0.1,
...
)链式优化器(Muon+AdamW 组合)。
class ChainedOptimizer:
def __init__(
self,
model: nn.Module,
optimizers: Dict[str, Optimizer],
flatten: bool = False,
)
def step(self) -> None
def zero_grad(self, set_to_none: bool = True) -> None
def __iter__(self)优化器工厂函数。
get_hyper_optimizer(
model: nn.Module,
muon_params: List[Dict[str, Any]],
adamw_params: List[Dict[str, Any]],
muon_kwargs: Optional[Dict[str, Any]] = None,
adamw_kwargs: Optional[Dict[str, Any]] = None,
) -> ChainedOptimizer参数:
| 参数 | 类型 | 说明 |
|---|---|---|
model |
nn.Module |
模型 |
muon_params |
List[Dict] |
Muon 参数组,空列表禁用 Muon |
adamw_params |
List[Dict] |
AdamW 参数组,空列表禁用 AdamW |
muon_kwargs |
Dict |
Muon 配置,支持 muon_ 前缀自动去前缀 |
adamw_kwargs |
Dict |
AdamW 配置,支持 adamw_ 前缀自动去前缀 |
学习率调度器工厂函数。
get_hyper_lr_scheduler(
optimizer: ChainedOptimizer,
total_steps: int,
warmup_steps: int = 0,
warmup_ratio: float = 0.0,
decay_style: str = "cosine",
lr: float = 1e-4,
lr_min: float = 1e-7,
lr_start: float = 0.0,
lr_decay_ratio: float = 1.0,
wsd_decay_steps: Optional[int] = None,
lr_wsd_decay_style: str = "exponential",
override_opt_param_scheduler: bool = False,
) -> LRSchedulersContainer支持的 decay_style: "constant" / "linear" / "cosine" / "WSD"
支持的 WSD decay_style: "linear" / "cosine" / "exponential" / "minus_sqrt"
函数式激活重计算。
checkpoint(
function,
*args,
swap_inputs: bool = False,
policy_fn: Optional[Callable] = None,
context_fn: Optional[Callable] = None,
group_swap: bool = False,
**kwargs,
)参数:
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
function |
callable |
— | 要 checkpoint 的函数 |
swap_inputs |
bool |
False |
是否将输入 swap 到 CPU |
policy_fn |
callable |
None |
逐 tensor 重计算策略 |
context_fn |
callable |
None |
上下文工厂(forward_ctx, recompute_ctx) |
group_swap |
bool |
False |
是否启用 swap 融合 |
函数式激活 swap。
swap(
function,
*args,
policy_fn: Optional[Callable] = None,
group_swap: bool = False,
**kwargs,
)参数:
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
function |
callable |
— | 要 swap 的函数 |
policy_fn |
callable |
None |
逐 tensor swap 策略 |
group_swap |
bool |
False |
是否启用 swap 融合 |
模块级 checkpoint 装饰器。
checkpoint_wrapper(module, policy="full", ...)模块级 swap 装饰器。
swap_wrapper(module, offload_to="cpu", ...)单 tensor swap 装饰器。
swap_tensor_wrapper(...)重计算策略枚举。
class CheckpointPolicy(enum.Enum):
MUST_SAVE = 0 # 必须保存
PREFER_SAVE = 1 # 优先保存
MUST_RECOMPUTE = 2 # 必须重计算
PREFER_RECOMPUTE = 3 # 优先重计算
MUST_SWAP = 4 # 必须 swap(需要 SwapManager)Swap 分组管理器(单例)。
class SwapManager:
def add_storage(self, group_name: str, storage: Storage) -> None
def ensure_group(self, group_name: str) -> None
def launch_offload(self, group_name: str, copy_stream=None) -> None
def protect_alias_tensors(self, group_name: str, tensors: Any) -> None获取当前平台对象。
get_platform() -> Platform返回 PyTorch 或 MindSpore 平台实现,用于访问平台特定功能。
__all__ = [
"get_platform", "DFunction", "fully_shard", "hsdp_sync_stream", "HSDPModule",
"DTensor", "Layout", "DeviceMesh", "init_device_mesh", "get_current_mesh",
"distribute_module", "init_parameters", "init_empty_weights", "init_on_device",
"shard_module", "custom_shard", "parallelize_value_and_grad", "SkipDTensorDispatch",
"MetaStep", "MetaStepType", "BatchDimSpec", "PipelineStage", "ScheduleInterleaved1F1B",
"init_process_group", "destroy_process_group", "get_process_group_ranks", "get_backend",
"split_group", "get_group_local_rank", "mark_created_groups",
"ContextParallel", "AsyncContextParallel",
"AsyncDSAIndexerContextParallel", "AsyncDSAIndexerLossContextParallel",
"AsyncDSASparseAttentionContextParallel",
"DSAIndexerContextParallel", "DSAIndexerLossContextParallel", "DSASparseAttentionContextParallel",
"ColwiseParallel", "RowwiseParallel", "SequenceParallel",
"PrepareModuleInput", "PrepareModuleInputOutput", "PrepareModuleOutput",
"ParallelStyle", "parallelize_module", "manual_seed",
]