DeepSeek-V3.2 源码架构详解
本文档面向初学者,详细解析 DeepSeek-V3.2 模型的源码实现,帮助读者理解其核心架构和运行机制。
精细度:本文档解析到每一个类、每一个函数、每一行关键代码的粒度。
目录
- 整体架构概览
- 基础组件详解
- 2.1 ModelArgs 配置类
- 2.2 ParallelEmbedding 并行嵌入层
- 2.3 Linear 及并行变体
- 2.4 RMSNorm 与 LayerNorm
- 位置编码详解
- 3.1 precompute_freqs_cis 预计算
- 3.2 apply_rotary_emb 应用旋转
- 3.3 YaRN 长度外推
- MLA 注意力机制
- 4.1 MLA 类结构
- 4.2 Query 路径详解
- 4.3 KV 路径详解
- 4.4 Prefill vs Decode 模式
- Indexer 稀疏注意力
- 5.1 Indexer 类结构
- 5.2 前向传播流程
- 5.3 Hadamard 变换
- 前馈网络
- MoE 混合专家系统
- 7.1 Gate 路由器
- 7.2 Expert 专家层
- 7.3 MoE 整体结构
- Block 与 Transformer
- 8.1 Block 结构
- 8.2 Transformer 主类
- 8.3 残差连接策略
- FP8 量化与 CUDA Kernel
- 推理流程
- 权重转换流程
1. 整体架构概览
1.1 模型架构图
graph TB
subgraph "DeepSeek-V3.2 Transformer"
Input[输入 Tokens] --> Embed[ParallelEmbedding
词嵌入层]
Embed --> Block1[Block 0
Dense FFN]
Block1 --> Block2[Block 1
Dense FFN]
Block2 --> Block3[Block 2
Dense FFN]
Block3 --> BlockN[Block 3-60
MoE FFN]
BlockN --> Norm[RMSNorm
最终归一化]
Norm --> Head[ColumnParallelLinear
输出投影]
Head --> Output[输出 Logits]
end
subgraph "每个 Block 内部结构"
direction LR
X[输入 x] --> AttnNorm[RMSNorm]
AttnNorm --> MLA[MLA 注意力]
MLA --> FFNNorm[RMSNorm]
FFNNorm --> FFN[MLP/MoE]
FFN --> Y[输出]
end
1.2 文件结构与代码行数
1inference/
2├── model.py # 923 行 - 核心模型定义
3├── generate.py # 187 行 - 推理生成逻辑
4├── kernel.py # 275 行 - Kernel 实现(TileLang JIT:act_quant / fp8_gemm / fp8_index)
5├── convert.py # 101 行 - 权重转换脚本(HF → 本仓库推理格式)
6├── config_671B_v3.2.json # 671B 推理 demo 的 ModelArgs 配置(与 generate.py 匹配)
7├── requirements.txt # 推理依赖(含 tilelang / fast_hadamard_transform)
8└── README.md # 推理 demo 使用说明
9encoding/
10└── encoding_dsv32.py # Chat/Tool 调用的 prompt 编码(与推理代码独立)
说明:本文档中若写
model.py/generate.py/kernel.py/convert.py,默认指inference/目录下同名文件。
1.3 类依赖关系图
classDiagram
Transformer --> ParallelEmbedding
Transformer --> Block
Transformer --> RMSNorm
Transformer --> ColumnParallelLinear
Block --> MLA
Block --> MLP
Block --> MoE
Block --> RMSNorm
MLA --> Linear
MLA --> ColumnParallelLinear
MLA --> RowParallelLinear
MLA --> RMSNorm
MLA --> Indexer
Indexer --> Linear
Indexer --> LayerNorm
MoE --> Gate
MoE --> Expert
MoE --> MLP
Expert --> Linear
MLP --> ColumnParallelLinear
MLP --> RowParallelLinear
ColumnParallelLinear --|> Linear
RowParallelLinear --|> Linear
2. 基础组件详解
2.1 ModelArgs 配置类
位置: model.py:17-90
1@dataclass
2class ModelArgs:
3 # 基础配置
4 max_batch_size: int = 8 # 最大批次大小
5 max_seq_len: int = 4096 * 4 # 最大序列长度 (16K,实际 V3.2 是 163K)
6 dtype: Literal["bf16", "fp8"] = "bf16" # 计算精度
7 scale_fmt: Optional[str] = None # FP8 激活量化 scale 的存储格式(如 ue8m0)
8 vocab_size: int = 102400 # 词表大小 (实际 V3.2 是 129280)
9 dim: int = 2048 # 隐藏层维度 (实际 V3.2 是 7168)
10 inter_dim: int = 10944 # Dense MLP 的中间维度(实际 V3.2 是 18432)
11 moe_inter_dim: int = 1408 # MoE expert 的中间维度(实际 V3.2 是 2048)
12 n_layers: int = 27 # 层数 (实际 V3.2 是 61)
13 n_dense_layers: int = 1 # Dense 层数 (实际 V3.2 是 3)
14 n_heads: int = 16 # 注意力头数 (实际 V3.2 是 128)
15
16 # MoE 配置
17 n_routed_experts: int = 64 # 路由专家数 (实际 V3.2 是 256)
18 n_shared_experts: int = 2 # 共享专家数 (实际 V3.2 是 1)
19 n_activated_experts: int = 6 # 激活专家数 (实际 V3.2 是 8)
20 n_expert_groups: int = 1 # 专家分组数 (实际 V3.2 是 8)
21 n_limited_groups: int = 1 # 选择的组数 (实际 V3.2 是 4)
22 score_func: str = "softmax" # 评分函数 (实际 V3.2 是 sigmoid)
23 route_scale: float = 1. # 路由缩放 (实际 V3.2 是 2.5)
24
25 # MLA 配置
26 q_lora_rank: int = 0 # Q 低秩维度 (实际 V3.2 是 1536)
27 kv_lora_rank: int = 512 # KV 低秩维度
28 qk_nope_head_dim: int = 128 # 无位置编码的 QK 维度
29 qk_rope_head_dim: int = 64 # 有位置编码的 QK 维度
30 v_head_dim: int = 128 # V 头维度
31
32 # YaRN 配置
33 original_seq_len: int = 4096 # 原始训练长度
34 rope_theta: float = 10000.0 # RoPE 基础频率
35 rope_factor: float = 40 # 长度外推因子
36 beta_fast: int = 32 # 快速衰减参数
37 beta_slow: int = 1 # 慢速衰减参数
38 mscale: float = 1. # YaRN 额外注意力缩放(影响 softmax_scale)
39
40 # Indexer 配置
41 index_n_heads: int = 64 # Indexer 头数
42 index_head_dim: int = 128 # Indexer 头维度
43 index_topk: int = 2048 # 选择的 token 数
配置流程图:
flowchart LR
subgraph "配置加载"
JSON[inference/config_671B_v3.2.json] --> Parse[json.load]
Parse --> Args[ModelArgs**kwargs]
Args --> Model[Transformer]
end
注意:仓库根目录的
config.json是 HuggingFace Transformers 的模型配置(字段名如hidden_size/num_hidden_layers/...),不能直接喂给ModelArgs(**json.load(...))。本推理 demo 使用的是inference/config_671B_v3.2.json(字段名与ModelArgs一一对应)。
2.2 ParallelEmbedding 并行嵌入层
位置: model.py:92-131
作用: 将词表按行切分到多个 GPU,实现词嵌入的张量并行。
2.2.1 初始化流程
flowchart TB
subgraph "ParallelEmbedding.__init__"
VS[vocab_size=129280] --> Check{vocab_size % world_size == 0?}
Check -->|Yes| Calc[part_vocab_size = 129280/8 = 16160]
Calc --> Range["vocab_start_idx = rank × 16160
vocab_end_idx = start + 16160"]
Range --> Weight["weight = Parameter(16160, 7168)"]
end
源码逐行解析:
1class ParallelEmbedding(nn.Module):
2 def __init__(self, vocab_size: int, dim: int):
3 super().__init__()
4 self.vocab_size = vocab_size # 完整词表大小: 129280
5 self.dim = dim # 嵌入维度: 7168
6
7 # 检查词表是否可被 GPU 数整除
8 assert vocab_size % world_size == 0
9
10 # 计算每个 GPU 负责的词表范围
11 self.part_vocab_size = vocab_size // world_size # 129280/8=16160
12 self.vocab_start_idx = rank * self.part_vocab_size # GPU0: 0, GPU1: 16160, ...
13 self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
14
15 # 只存储本 GPU 负责的词表部分
16 self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
2.2.2 关于 world_size 的说明
代码顶部的 world_size = 1 只是默认值,实际运行时会被动态修改:
1# model.py:13-14 (默认值)
2world_size = 1
3rank = 0
4
5# model.py:873-875 (运行时修改)
6class Transformer(nn.Module):
7 def __init__(self, args: ModelArgs):
8 global world_size, rank # 修改全局变量
9 world_size = dist.get_world_size() if dist.is_initialized() else 1
10 rank = dist.get_rank() if dist.is_initialized() else 0
| 运行场景 | world_size |
part_vocab_size |
|---|---|---|
| 单卡运行 | 1 | 129280 (完整词表) |
8卡运行 (torchrun --nproc_per_node=8) |
8 | 16160 |
2.2.3 前向传播通俗解释:分工合作查字典
想象词表就是一本 12 万词的大字典,太大了一个人查不过来,于是让 8 个 GPU 分工:
1GPU 0: 负责词 0 ~ 16159 (相当于 A-C 开头的词)
2GPU 1: 负责词 16160 ~ 32319 (相当于 D-F 开头的词)
3GPU 2: 负责词 32320 ~ 48479 (相当于 G-I 开头的词)
4...
5GPU 7: 负责词 113120 ~ 129279 (相当于 X-Z 开头的词)
具体例子:假设输入是 x = [100, 20000, 50000](3 个词的 ID)
Step 1: 每个 GPU 判断"这是不是我的活"
1mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
| GPU | 负责范围 | 词 100 | 词 20000 | 词 50000 | mask 结果 |
|---|---|---|---|---|---|
| GPU 0 | 0~16159 | ✓ 我的 | ✗ 不是 | ✗ 不是 | [False, True, True] |
| GPU 1 | 16160~32319 | ✗ 不是 | ✓ 我的 | ✗ 不是 | [True, False, True] |
| GPU 3 | 48480~64639 | ✗ 不是 | ✗ 不是 | ✓ 我的 | [True, True, False] |
Step 2: 转换成"本地索引"
1x = x - self.vocab_start_idx # 减去起始位置
2x[mask] = 0 # 不是我的活就填 0(防止数组越界)
| GPU | 原始 x | 减去起始偏移 | 不是我的填 0 |
|---|---|---|---|
| GPU 0 (起始=0) | [100, 20000, 50000] |
[100, 20000, 50000] |
[100, 0, 0] |
| GPU 1 (起始=16160) | [100, 20000, 50000] |
[-16060, 3840, 33840] |
[0, 3840, 0] |
| GPU 3 (起始=48480) | [100, 20000, 50000] |
[-48380, -28480, 1520] |
[0, 0, 1520] |
Step 3: 各自查表
1y = F.embedding(x, self.weight)
| GPU | 查表结果 |
|---|---|
| GPU 0 | [词100的向量, 词0的向量(垃圾), 词0的向量(垃圾)] |
| GPU 1 | [词0的向量(垃圾), 词20000的向量, 词0的向量(垃圾)] |
| GPU 3 | [词0的向量(垃圾), 词0的向量(垃圾), 词50000的向量] |
Step 4: 把"垃圾数据"清零
1y[mask] = 0 # 不是我负责的位置,结果清零
| GPU | 清零后的结果 |
|---|---|
| GPU 0 | [词100的向量, 0, 0] |
| GPU 1 | [0, 词20000的向量, 0] |
| GPU 3 | [0, 0, 词50000的向量] |
Step 5: 汇总所有 GPU 的结果
1dist.all_reduce(y) # 所有 GPU 的结果相加
1 [词100的向量, 0, 0 ] ← GPU 0 的贡献
2+ [0, 词20000的向量, 0 ] ← GPU 1 的贡献
3+ [0, 0, 词50000的向量] ← GPU 3 的贡献
4+ ... ← 其他 GPU 贡献 0
5─────────────────────────────────────────────────────────────────
6= [词100的向量, 词20000的向量, 词50000的向量] ← 最终结果
流程图总结:
flowchart TB
subgraph "输入"
X["x = [100, 20000, 50000]
3个词需要查询"]
end
subgraph "GPU 0 (负责词 0-16159)"
M0["Step1: 判断
mask = [F, T, T]
只有100是我的"]
L0["Step2: 本地索引
[100, 0, 0]"]
E0["Step3: 查表
[emb_100, 垃圾, 垃圾]"]
C0["Step4: 清零
[emb_100, 0, 0]"]
end
subgraph "GPU 1 (负责词 16160-32319)"
M1["Step1: 判断
mask = [T, F, T]
只有20000是我的"]
L1["Step2: 本地索引
[0, 3840, 0]"]
E1["Step3: 查表
[垃圾, emb_20000, 垃圾]"]
C1["Step4: 清零
[0, emb_20000, 0]"]
end
subgraph "GPU 3 (负责词 48480-64639)"
M3["Step1: 判断
mask = [T, T, F]
只有50000是我的"]
L3["Step2: 本地索引
[0, 0, 1520]"]
E3["Step3: 查表
[垃圾, 垃圾, emb_50000]"]
C3["Step4: 清零
[0, 0, emb_50000]"]
end
X --> M0 --> L0 --> E0 --> C0
X --> M1 --> L1 --> E1 --> C1
X --> M3 --> L3 --> E3 --> C3
C0 --> AR["Step5: all_reduce
所有结果相加"]
C1 --> AR
C3 --> AR
AR --> Y["最终结果
y = [emb_100, emb_20000, emb_50000]"]
一句话总结:每个 GPU 只查自己负责的词,查到了就返回向量,查不到就返回 0,最后把所有 GPU 的结果加起来。
2.2.4 源码逐行解析
1def forward(self, x: torch.Tensor) -> torch.Tensor:
2 if world_size > 1:
3 # 1. 创建 mask: 标记不属于本 GPU 的 token
4 mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
5
6 # 2. 转换为本地索引 (减去起始偏移)
7 x = x - self.vocab_start_idx
8
9 # 3. 不属于本 GPU 的 token 索引设为 0 (防止越界)
10 x[mask] = 0
11
12 # 4. 查表得到嵌入
13 y = F.embedding(x, self.weight)
14
15 if world_size > 1:
16 # 5. 不属于本 GPU 的结果清零
17 y[mask] = 0
18
19 # 6. 所有 GPU 结果求和 (all_reduce)
20 dist.all_reduce(y)
21
22 return y
2.2.5 显存对比
| 配置 | 词表显存 |
|---|---|
| 单卡完整词表 | 129280 × 7168 × 2B = 1.73 GB |
| 8卡并行 (每卡) | 16160 × 7168 × 2B = 216 MB |
2.3 Linear 及并行变体
位置: model.py:134-269
2.3.1 通俗解释:分工合作算矩阵
想象你要计算一个超大的矩阵乘法 y = x × W,其中 W 是一个 7168 × 18432 的巨型矩阵(约 2.6 亿个数字)。单卡显存放不下,怎么办?
两种切分策略:
策略一:列并行 (ColumnParallelLinear) - “每人算一部分结果”
1想象你要做一道数学题:计算 [1, 2, 3] × [[a, b, c, d],
2 [e, f, g, h],
3 [i, j, k, l]]
4
5结果是 [y1, y2, y3, y4]
6
7列并行就是让 2 个 GPU 分工:
8- GPU 0: 只算 [y1, y2] (对应矩阵的前 2 列)
9- GPU 1: 只算 [y3, y4] (对应矩阵的后 2 列)
flowchart LR
subgraph "列并行:输入相同,输出切分"
X["x: [B,S,7168]
完整输入"] --> GPU0["GPU0: W的前半列
[7168, 9216]"]
X --> GPU1["GPU1: W的后半列
[7168, 9216]"]
GPU0 --> Y0["y0: [B,S,9216]"]
GPU1 --> Y1["y1: [B,S,9216]"]
end
特点:
- ✅ 输入不需要通信(每个 GPU 都有完整的 x)
- ✅ 输出直接使用(不需要 all_reduce)
- ⚠️ 输出是"切片",后续层要知道怎么用
策略二:行并行 (RowParallelLinear) - “每人算一部分加数”
1想象你要做一道数学题:计算 [1, 2, 3, 4] × [[a],
2 [b],
3 [c],
4 [d]]
5
6结果是 y = 1×a + 2×b + 3×c + 4×d
7
8行并行就是让 2 个 GPU 分工:
9- GPU 0: 计算 1×a + 2×b (输入的前半部分 × 权重的前半行)
10- GPU 1: 计算 3×c + 4×d (输入的后半部分 × 权重的后半行)
11- 最后:all_reduce 求和得到 y
flowchart LR
subgraph "行并行:输入切分,输出求和"
X0["x0: [B,S,3584]
输入前半"] --> GPU0["GPU0: W的前半行
[3584, 7168]"]
X1["x1: [B,S,3584]
输入后半"] --> GPU1["GPU1: W的后半行
[3584, 7168]"]
GPU0 --> Y0["y0: [B,S,7168]
(部分和)"]
GPU1 --> Y1["y1: [B,S,7168]
(部分和)"]
Y0 --> AR["all_reduce
求和"]
Y1 --> AR
AR --> Y["y: [B,S,7168]
完整结果"]
end
特点:
- ⚠️ 输入必须是切片(通常来自上一层的列并行)
- ⚠️ 输出需要 all_reduce 通信
- ✅ 结果是完整的,可以直接用
组合使用的例子:MLP 层
1MLP 有 3 个矩阵:W1, W2, W3
2
3设计思路:
4- W1 和 W3 用列并行 → 输出是切片
5- W2 用行并行 → 输入正好是切片,输出是完整结果
6
7这样刚好"接龙",只需要在 W2 后做一次 all_reduce!
flowchart LR
X["x: 完整"] --> W1["W1: 列并行"]
X --> W3["W3: 列并行"]
W1 --> SILU["SiLU"]
SILU --> MUL["×"]
W3 --> MUL
MUL --> W2["W2: 行并行
(含 all_reduce)"]
W2 --> Y["y: 完整"]
2.3.2 linear 函数详解
1def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
2 scale_fmt: Optional[str] = None) -> torch.Tensor:
3 assert bias is None # 不支持 bias
4
5 if weight.dtype != torch.float8_e4m3fn:
6 # BF16 权重: 使用标准 F.linear
7 return F.linear(x, weight)
8 else:
9 # FP8 权重: 量化输入后使用 FP8 GEMM
10 x, scale = act_quant(x, block_size, scale_fmt) # 量化激活
11 return fp8_gemm(x, scale, weight, weight.scale) # FP8 矩阵乘
通俗解释:
- BF16 模式:直接调用 PyTorch 的标准矩阵乘法
- FP8 模式:
- 先把输入 x 从 BF16 “压缩"成 FP8(节省显存和带宽)
- 用专门的 FP8 矩阵乘法 Kernel(Tensor Core 加速)
- 结果自动转回 BF16
FP8 vs BF16 计算流程:
flowchart TB
subgraph "BF16 模式(简单直接)"
X1[x: BF16] --> Linear1[F.linear]
W1[weight: BF16] --> Linear1
Linear1 --> Y1[y: BF16]
end
subgraph "FP8 模式(压缩后计算)"
X2[x: BF16] --> Quant[act_quant
量化]
Quant --> X_FP8[x: FP8
1字节/数]
Quant --> X_Scale[x_scale: FP32
缩放因子]
X_FP8 --> GEMM[fp8_gemm
FP8矩阵乘]
X_Scale --> GEMM
W2[weight: FP8] --> GEMM
W_Scale[weight.scale: FP32] --> GEMM
GEMM --> Y2[y: BF16]
end
2.3.3 Linear 类
1class Linear(nn.Module):
2 dtype = torch.bfloat16 # 类级别默认精度
3 scale_fmt: Optional[str] = None
4
5 def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
6 super().__init__()
7 self.in_features = in_features
8 self.out_features = out_features
9
10 # 创建权重参数
11 self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
12
13 # FP8 模式下需要额外的 scale 参数
14 if self.weight.element_size() == 1: # FP8 每个元素 1 字节
15 scale_out = (out_features + block_size - 1) // block_size # 向上取整
16 scale_in = (in_features + block_size - 1) // block_size
17 # scale 按 128×128 块存储
18 self.weight.scale = self.scale = nn.Parameter(
19 torch.empty(scale_out, scale_in, dtype=torch.float32)
20 )
21 else:
22 self.register_parameter("scale", None)
FP8 权重与 Scale 的关系:
flowchart TB
subgraph "权重矩阵 (7168 × 1536)"
W[weight: FP8
shape: 7168×1536]
end
subgraph "Scale 矩阵 (56 × 12)"
S["scale: FP32
shape: ceil(7168/128) × ceil(1536/128)
= 56 × 12"]
end
subgraph "对应关系"
B1["Block (0,0)
128×128"] --> S1["scale[0,0]"]
B2["Block (0,1)
128×128"] --> S2["scale[0,1]"]
BN["Block (55,11)
128×128"] --> SN["scale[55,11]"]
end
2.3.3 ColumnParallelLinear (列并行)
作用: 将输出维度切分到多个 GPU。
flowchart LR
subgraph "ColumnParallelLinear"
X[x: B×S×D] --> GPU0["GPU0: W[0:N/8, :]"]
X --> GPU1["GPU1: W[N/8:2N/8, :]"]
X --> GPU7["GPU7: W[7N/8:N, :]"]
GPU0 --> Y0["y0: B×S×N/8"]
GPU1 --> Y1["y1: B×S×N/8"]
GPU7 --> Y7["y7: B×S×N/8"]
end
1class ColumnParallelLinear(Linear):
2 def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
3 assert out_features % world_size == 0
4 self.part_out_features = out_features // world_size # 输出维度切分
5 super().__init__(in_features, self.part_out_features, bias, dtype)
6
7 def forward(self, x: torch.Tensor) -> torch.Tensor:
8 # 直接计算,输出是完整输出的一部分
9 return linear(x, self.weight, self.bias, self.scale_fmt)
2.3.4 RowParallelLinear (行并行)
作用: 将输入维度切分到多个 GPU,计算后需要 all_reduce 合并。
flowchart LR
subgraph "RowParallelLinear"
X0["x0: B×S×D/8"] --> GPU0["GPU0: W[:, 0:D/8]"]
X1["x1: B×S×D/8"] --> GPU1["GPU1: W[:, D/8:2D/8]"]
X7["x7: B×S×D/8"] --> GPU7["GPU7: W[:, 7D/8:D]"]
GPU0 --> Y0["y0: B×S×N"]
GPU1 --> Y1["y1: B×S×N"]
GPU7 --> Y7["y7: B×S×N"]
Y0 --> AR[all_reduce]
Y1 --> AR
Y7 --> AR
AR --> Y["y: B×S×N"]
end
1class RowParallelLinear(Linear):
2 def __init__(self, in_features: int, out_features: int, bias: bool = False,
3 reduce_output=True, dtype=None):
4 assert in_features % world_size == 0
5 self.part_in_features = in_features // world_size # 输入维度切分
6 self.reduce_output = reduce_output
7 super().__init__(self.part_in_features, out_features, bias, dtype)
8
9 def forward(self, x: torch.Tensor) -> torch.Tensor:
10 y = linear(x, self.weight, None, self.scale_fmt)
11
12 if self.reduce_output and world_size > 1:
13 y = y.float()
14 dist.all_reduce(y) # 关键: 合并所有 GPU 的结果
15
16 if self.bias is not None:
17 y += self.bias
18
19 return y.type_as(x)
2.3.5 并行策略对比
| 类型 | 切分维度 | 通信操作 | 适用场景 |
|---|---|---|---|
| ColumnParallel | 输出 (dim=0) | 无 (输出直接使用) | 后接 RowParallel |
| RowParallel | 输入 (dim=1) | all_reduce | 后接 ColumnParallel 或输出 |
2.4 RMSNorm 与 LayerNorm
位置: model.py:272-321
2.4.1 通俗解释:给向量"校准刻度”
为什么需要归一化?
想象你有一组测量数据:[0.001, 0.002, 0.003] 和另一组 [1000, 2000, 3000]。虽然它们的"相对关系"一样,但数值差了百万倍!
神经网络在训练/推理时,不同层输出的数值范围可能差异很大。如果不处理:
- 有些层输出
0.001,有些输出10000 - 后面的计算会被大数值"淹没"
- 梯度可能爆炸或消失
归一化的作用:把每个向量"校准"到统一的刻度,通常让数值在 -1 到 1 附近。
RMSNorm vs LayerNorm:两种"校准"方式
| 方法 | LayerNorm | RMSNorm |
|---|---|---|
| 步骤 | 1. 减去均值 2. 除以标准差 | 只除以 RMS(均方根) |
| 公式 | $\frac{x - \mu}{\sigma}$ | $\frac{x}{\text{RMS}(x)}$ |
| 参数 | weight + bias | 只有 weight |
| 计算量 | 较多 | 较少 |
| 效果 | 更"中心化" | 保持原始分布形状 |
通俗比喻:
- LayerNorm = “把考试成绩换算成标准分”(减均值、除标准差)
- RMSNorm = “把成绩按总分折算成百分比”(只做缩放,不平移)
具体例子:假设输入向量 x = [3, 4]
1RMSNorm 计算过程:
2
3Step 1: 计算平方
4 x² = [9, 16]
5
6Step 2: 计算均方 (Mean of Squares)
7 mean(x²) = (9 + 16) / 2 = 12.5
8
9Step 3: 计算均方根 (Root Mean Square)
10 RMS = √12.5 ≈ 3.54
11
12Step 4: 归一化
13 y = x / RMS = [3/3.54, 4/3.54] ≈ [0.85, 1.13]
14
15Step 5: 乘以可学习权重 γ
16 y = y × γ = [0.85×γ₁, 1.13×γ₂]
flowchart LR
X["x = [3, 4]"] --> Square["x² = [9, 16]"]
Square --> Mean["mean = 12.5"]
Mean --> Add["+ ε = 12.500001"]
Add --> Rsqrt["1/√... ≈ 0.283"]
X --> Mul1["×"]
Rsqrt --> Mul1
Mul1 --> Result["[0.85, 1.13]"]
Result --> Mul2["× γ"]
Mul2 --> Y["最终输出"]
2.4.2 RMSNorm 公式与源码
公式: $$\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}} \cdot \gamma$$
源码解析:
1class RMSNorm(nn.Module):
2 def __init__(self, dim: int, eps: float = 1e-6):
3 super().__init__()
4 self.dim = dim
5 self.eps = eps
6 # 注意: weight 是 FP32 精度
7 self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
8
9 def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
10 dtype = x.dtype
11
12 if residual is None:
13 # 模式1: 普通归一化
14 x = x.float()
15 var = x.pow(2).mean(-1, keepdim=True) # 计算方差 (实际是均方)
16 x = x * torch.rsqrt(var + self.eps) # 归一化
17 return (self.weight * x).to(dtype)
18 else:
19 # 模式2: 融合残差的归一化 (优化显存)
20 x = residual = x.float() + residual.float() # 先加残差
21 var = x.pow(2).mean(-1, keepdim=True)
22 x = x * torch.rsqrt(var + self.eps)
23 return (self.weight * x).to(dtype), residual.to(dtype)
2.4.2 LayerNorm
1class LayerNorm(nn.Module):
2 def __init__(self, dim: int, eps: float = 1e-6):
3 super().__init__()
4 self.dim = dim
5 self.eps = eps
6 self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
7 self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) # 比 RMSNorm 多一个 bias
8
9 def forward(self, x: torch.Tensor):
10 return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
2.4.3 RMSNorm vs LayerNorm
| 特性 | RMSNorm | LayerNorm |
|---|---|---|
| 公式 | $\frac{x}{\text{RMS}(x)} \cdot \gamma$ | $\frac{x - \mu}{\sigma} \cdot \gamma + \beta$ |
| 参数 | weight | weight + bias |
| 计算量 | 更少 (无需计算均值) | 更多 |
| 使用场景 | 主干网络 | Indexer 的 k_norm |
3. 位置编码详解
3.0 通俗解释:让模型知道"谁在前谁在后"
为什么需要位置编码?
看这两句话:
- “狗咬了人”
- “人咬了狗”
词完全一样,但意思完全不同!区别在于词的顺序。
Transformer 的注意力机制本身是"位置无关"的——它只看词与词之间的关系,不知道谁在前谁在后。所以我们需要给每个位置一个"身份标签"。
传统方法 vs RoPE:两种"标签"方式
| 方法 | 传统位置编码 | RoPE (旋转位置编码) |
|---|---|---|
| 思路 | 给每个位置加一个固定向量 | 根据位置旋转向量 |
| 类比 | 给每个学生发一个号码牌 | 让每个学生转不同的角度 |
| 优点 | 简单直观 | 可以表示相对位置,支持长度外推 |
RoPE 的核心思想:用"旋转"表示位置
想象一个时钟:
- 位置 0 的向量:指向 12 点方向
- 位置 1 的向量:顺时针转一点
- 位置 2 的向量:再转一点
- …
1位置 0: ↑ 位置 1: ↗ 位置 2: → 位置 3: ↘
2 | / — \
关键特性:两个位置的"相对距离"可以通过旋转角度的差来表示!
比如位置 5 和位置 3 的相对距离是 2,无论它们在序列的哪里,旋转角度差都一样。
为什么用复数?
复数乘法天然就是旋转!
1复数 z = a + bi 可以表示为 (长度, 角度)
2复数乘法 z1 × z2 = (长度1×长度2, 角度1+角度2)
所以 向量 × e^(iθ) 就是把向量旋转 θ 角度。
具体例子:假设向量 x = [1, 0],位置 m=3,频率 θ=30°
1Step 1: 把向量看作复数
2 x = 1 + 0i
3
4Step 2: 计算旋转角度
5 角度 = m × θ = 3 × 30° = 90°
6
7Step 3: 计算旋转因子
8 e^(i×90°) = cos(90°) + i×sin(90°) = 0 + 1i
9
10Step 4: 复数乘法
11 x' = (1 + 0i) × (0 + 1i) = 0 + 1i
12
13Step 5: 转回向量
14 x' = [0, 1] (从指向右变成指向上,旋转了 90°)
flowchart LR
subgraph "RoPE 旋转示意"
X["原始向量 [1,0]
→"] --> Rotate["× e^(i×90°)"]
Rotate --> Y["旋转后 [0,1]
↑"]
end
3.1 precompute_freqs_cis 预计算
位置: model.py:324-402
作用: 预计算所有位置的旋转频率复数表示。
为什么要"预计算"?
- 每个位置的旋转角度是固定的(只跟位置有关)
- 提前算好存起来,推理时直接查表,省时间
flowchart TB
subgraph "频率计算"
Base["base = 10000"] --> Freqs["freqs = 1/(base^(2i/d))"]
Freqs --> Check{"seq_len > original?"}
Check -->|Yes| YaRN["YaRN 调整频率"]
Check -->|No| Direct["直接使用"]
end
subgraph "复数表示"
T["positions = [0,1,2,...,seq_len-1]"] --> Outer["outer(positions, freqs)"]
Outer --> Polar["polar(1, angles)"]
Polar --> FreqsCis["freqs_cis: 复数 e^(iθ)"]
end
源码逐行解析:
1def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
2 dim = args.qk_rope_head_dim # 64
3 seqlen = args.max_seq_len # 163840
4 base = args.rope_theta # 10000
5 factor = args.rope_factor # 40
6
7 # 计算基础频率: θ_i = 1 / (10000^(2i/64))
8 # i ∈ [0, 2, 4, ..., 62], 共 32 个
9 freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
10 # freqs.shape = [32], 值从 1.0 衰减到 约 0.01
11
12 # 如果需要外推 (163840 > 4096)
13 if seqlen > args.original_seq_len:
14 # YaRN 调整 (详见 3.3 节)
15 low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
16 smooth = 1 - linear_ramp_factor(low, high, dim // 2)
17 freqs = freqs / factor * (1 - smooth) + freqs * smooth
18
19 # 计算每个位置的角度
20 t = torch.arange(seqlen) # [0, 1, 2, ..., 163839]
21 freqs = torch.outer(t, freqs) # [163840, 32] - 外积
22
23 # 转换为复数表示: e^(iθ) = cos(θ) + i*sin(θ)
24 freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
25 # freqs_cis.shape = [163840, 32], dtype = complex64
26
27 return freqs_cis
3.2 apply_rotary_emb 应用旋转
位置: model.py:405-425
RoPE 原理:
将向量看作复数,通过乘以 $e^{i m \theta}$ 实现位置相关的旋转:
$$ \text{RoPE}(x, m) = x \cdot e^{im\theta} $$
其中 $m$ 是位置,$\theta$ 是频率。
flowchart TB
subgraph "apply_rotary_emb"
X["x: [B,S,H,D]
D=64 (qk_rope_head_dim)"] --> View["view_as_complex
[B,S,H,32] complex"]
FreqsCis["freqs_cis: [S,32] complex"] --> Broadcast["broadcast to [1,S,1,32]"]
View --> Mul["x × freqs_cis"]
Broadcast --> Mul
Mul --> Real["view_as_real
[B,S,H,32,2]"]
Real --> Flatten["flatten → [B,S,H,64]"]
end
源码逐行解析:
1def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
2 dtype = x.dtype
3 shape = x.shape # [B, S, H, D]
4
5 if not interleaved:
6 # 非交错模式: [x0,x1,x2,...,x31,x32,...,x63] → [x0,x32,x1,x33,...,x31,x63]
7 x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
8
9 # 将实数对看作复数: [B,S,H,D] → [B,S,H,D/2] complex
10 x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
11
12 # 广播 freqs_cis 到 [1, S, 1, D/2]
13 freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
14
15 # 复数乘法实现旋转
16 y = torch.view_as_real(x * freqs_cis).flatten(3) # [B,S,H,D]
17
18 if not interleaved:
19 # 恢复非交错布局
20 y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
21
22 return y.to(dtype)
3.3 YaRN 长度外推
问题: 模型在 4096 长度上训练,如何外推到 163840?
解决方案: YaRN (Yet another RoPE extensioN) 对不同频率的维度采用不同的缩放策略。
flowchart TB
subgraph "YaRN 策略"
Low["低频维度 (i < low)
线性插值: freqs / factor"]
High["高频维度 (i > high)
保持不变: freqs"]
Mid["中间维度 (low ≤ i ≤ high)
平滑过渡"]
end
subgraph "参数含义"
BetaFast["beta_fast=32
高频旋转次数阈值"]
BetaSlow["beta_slow=1
低频旋转次数阈值"]
Factor["factor=40
外推倍数"]
end
源码中的辅助函数:
1def find_correction_dim(num_rotations, dim, base, max_seq_len):
2 """计算给定旋转次数对应的维度索引"""
3 return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
4
5def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
6 """计算需要调整的维度范围"""
7 low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
8 high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
9 return max(low, 0), min(high, dim-1)
10
11def linear_ramp_factor(min, max, dim):
12 """生成平滑过渡的线性斜坡"""
13 linear_func = (torch.arange(dim) - min) / (max - min)
14 return torch.clamp(linear_func, 0, 1)
3.3.1 mscale:额外注意力缩放(实现细节)
除 RoPE 频率缩放外,MLA 里还会在“外推场景”额外调节 softmax 的缩放系数(只在 max_seq_len > original_seq_len 时生效):
1if args.max_seq_len > args.original_seq_len:
2 mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
3 self.softmax_scale = self.softmax_scale * mscale * mscale
这意味着在长上下文外推时,注意力的温度会随 rope_factor/mscale 一起被重新校准;对应参数在 inference/config_671B_v3.2.json 中提供。
4. MLA 注意力机制
4.0 通俗解释:压缩记忆的艺术
问题:KV Cache 太大了!
传统注意力机制需要缓存所有历史 token 的 K 和 V 向量。对于 DeepSeek-V3.2:
1传统 KV Cache 大小计算:
2- 序列长度: 163840 tokens
3- K 维度: 128 头 × 192 维 = 24576
4- V 维度: 128 头 × 128 维 = 16384
5- 每层缓存: 163840 × (24576 + 16384) × 2B ≈ 13.4 GB
6- 61 层总共: 13.4 × 61 ≈ 817 GB 😱
这显然不可接受!
MLA 的核心思想:先压缩再缓存
想象你要记笔记:
- 传统方法:把老师说的每句话都逐字记下来 → 笔记本爆炸
- MLA 方法:先把老师的话"压缩"成关键词,只记关键词 → 需要时再"展开"
1传统:x → K (24576维) → 缓存 K
2MLA: x → c_KV (512维) → 缓存 c_KV → 展开成 K (24576维)
3
4压缩率: 512 / 24576 ≈ 2%!
MLA Cache 大小计算:
1MLA KV Cache 大小:
2- 序列长度: 163840 tokens
3- 压缩后维度: 512 + 64 (RoPE) = 576
4- 每层缓存: 163840 × 576 × 2B ≈ 188 MB
5- 61 层总共: 188 × 61 ≈ 11.5 GB ✅
6
7节省: 817 GB → 11.5 GB,减少 98.6%!
MLA 的三个关键概念
| 概念 | 作用 | 类比 |
|---|---|---|
| 低秩压缩 | 用小矩阵近似大矩阵 | 用"摘要"代替"全文" |
| nope + rope 分离 | 把向量分成两部分处理 | 有些信息需要位置,有些不需要 |
| 延迟展开 | 缓存压缩版,用时再展开 | 压缩包存硬盘,用时解压 |
Q/K/V 的压缩与展开流程
flowchart TB
subgraph "Query 路径"
X1["x: 7168维"] --> WQA["wq_a: 压缩
7168→1536"]
WQA --> QNorm["RMSNorm"]
QNorm --> WQB["wq_b: 展开
1536→24576"]
WQB --> Q["Q: 128头×192维"]
end
subgraph "KV 路径 (关键!)"
X2["x: 7168维"] --> WKVA["wkv_a: 压缩
7168→576"]
WKVA --> Split["分离"]
Split --> CKV["c_kv: 512维
(缓存这个!)"]
Split --> KPE["k_pe: 64维
(位置编码)"]
CKV --> KVNorm["RMSNorm"]
KVNorm --> Cache["💾 存入缓存"]
KVNorm --> WKVB["wkv_b: 展开
512→32768"]
WKVB --> KV["K_nope + V"]
end
为什么叫"Multi-head Latent Attention"?
- Multi-head: 多头注意力,每个头看不同的"视角"
- Latent: “潜在的”,指那个压缩后的小向量 c_KV
- Attention: 注意力机制
一句话总结:MLA 把 KV 压缩成一个"潜在向量"缓存起来,用的时候再展开,节省 98% 的显存。
4.1 MLA 类结构
位置: model.py:498-608
MLA vs 传统 MHA 对比:
graph LR
subgraph "传统 MHA"
X1[x] --> WQ["W_Q: D→H×D_h"]
X1 --> WK["W_K: D→H×D_h"]
X1 --> WV["W_V: D→H×D_h"]
WQ --> Q1[Q]
WK --> K1["K (缓存)"]
WV --> V1["V (缓存)"]
end
subgraph "MLA"
X2[x] --> WQA["W_QA: D→R_q"]
WQA --> QNorm[RMSNorm]
QNorm --> WQB["W_QB: R_q→H×D_h"]
WQB --> Q2[Q]
X2 --> WKVA["W_KVA: D→R_kv+D_rope"]
WKVA --> KVNorm[RMSNorm]
KVNorm --> C_KV["c_KV (只缓存这个!)"]
C_KV --> WKVB["W_KVB: R_kv→H×(D_nope+D_v)"]
WKVB --> K2[K]
WKVB --> V2[V]
end
类初始化:
1class MLA(nn.Module):
2 def __init__(self, args: ModelArgs):
3 super().__init__()
4 self.dim = args.dim # 7168
5 self.n_heads = args.n_heads # 128
6 self.n_local_heads = args.n_heads // world_size # 128/8=16 (每 GPU)
7 self.q_lora_rank = args.q_lora_rank # 1536
8 self.kv_lora_rank = args.kv_lora_rank # 512
9 self.qk_nope_head_dim = args.qk_nope_head_dim # 128
10 self.qk_rope_head_dim = args.qk_rope_head_dim # 64
11 self.qk_head_dim = 128 + 64 # 192
12 self.v_head_dim = args.v_head_dim # 128
13
14 # === Query 路径 ===
15 self.wq_a = Linear(7168, 1536) # 压缩
16 self.q_norm = RMSNorm(1536)
17 self.wq_b = ColumnParallelLinear(1536, 128*192) # 展开 (列并行)
18
19 # === KV 路径 ===
20 self.wkv_a = Linear(7168, 512+64) # 压缩 (512 for KV, 64 for RoPE)
21 self.kv_norm = RMSNorm(512)
22 self.wkv_b = ColumnParallelLinear(512, 128*(128+128)) # 展开 (k_nope + v)
23
24 # === 输出 ===
25 self.wo = RowParallelLinear(128*128, 7168) # 行并行
26
27 # === Indexer ===
28 self.indexer = Indexer(args)
29
30 # === KV Cache ===
31 self.register_buffer("kv_cache", torch.zeros(8, 163840, 512)) # 只缓存压缩的 c_KV
32 self.register_buffer("pe_cache", torch.zeros(8, 163840, 64)) # 缓存 k_pe
4.2 Query 路径详解
flowchart TB
subgraph "Query 计算"
X["x: [B,S,7168]"] --> WQA["wq_a
Linear(7168→1536)"]
WQA --> C_Q["c_q: [B,S,1536]"]
C_Q --> QNorm["q_norm
RMSNorm"]
QNorm --> QR["qr: [B,S,1536]
(也传给 Indexer)"]
QR --> WQB["wq_b
ColumnParallel(1536→128×192/8)"]
WQB --> Q["q: [B,S,16,192]
(每 GPU 16 头)"]
Q --> Split["split(128, 64)"]
Split --> Q_NOPE["q_nope: [B,S,16,128]"]
Split --> Q_PE["q_pe: [B,S,16,64]"]
Q_PE --> RoPE["apply_rotary_emb"]
RoPE --> Q_PE_ROT["q_pe (旋转后)"]
end
源码:
1# 在 forward() 中
2qr = self.q_norm(self.wq_a(x)) # [B,S,1536] - 压缩 + 归一化
3q = self.wq_b(qr) # [B,S,16*192] - 展开 (列并行后每 GPU 16 头)
4q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) # [B,S,16,192]
5
6# 分离 nope 和 rope 部分
7q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
8# q_nope: [B,S,16,128], q_pe: [B,S,16,64]
9
10q_pe = apply_rotary_emb(q_pe, freqs_cis) # 应用位置编码
4.3 KV 路径详解
flowchart TB
subgraph "KV 计算"
X["x: [B,S,7168]"] --> WKVA["wkv_a
Linear(7168→576)"]
WKVA --> KV_PE["[B,S,576]"]
KV_PE --> Split1["split(512, 64)"]
Split1 --> KV["kv: [B,S,512]"]
Split1 --> K_PE["k_pe: [B,S,64]"]
KV --> KVNorm["kv_norm
RMSNorm"]
KVNorm --> KV_Norm["kv (归一化)"]
KV_Norm --> FP8["FP8 量化模拟"]
FP8 --> Cache1["存入 kv_cache"]
K_PE --> RoPE2["apply_rotary_emb"]
RoPE2 --> K_PE_ROT["k_pe (旋转后)"]
K_PE_ROT --> Cache2["存入 pe_cache"]
KV_Norm --> WKVB["wkv_b
ColumnParallel(512→128×256/8)"]
WKVB --> KV_Expand["[B,S,16,256]"]
KV_Expand --> Split2["split(128, 128)"]
Split2 --> K_NOPE["k_nope: [B,S,16,128]"]
Split2 --> V["v: [B,S,16,128]"]
end
源码:
1# 在 forward() 中
2kv = self.wkv_a(x) # [B,S,576] - 压缩
3kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
4# kv: [B,S,512], k_pe: [B,S,64]
5
6kv = self.kv_norm(kv) # 归一化
7k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) # 位置编码
8
9# FP8 量化模拟 (为了与部署一致)
10kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
11kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
12
13# 存入缓存
14self.kv_cache[:bsz, start_pos:end_pos] = kv # 只存 512 维!
15self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
4.4 Prefill vs Decode 模式
4.4.1 Prefill 模式 (处理整个 prompt)
条件: mask is not None (seqlen > 1)
flowchart TB
subgraph "Prefill 模式"
Q["Q: [B,S,16,192]"] --> QK["Q @ K^T"]
K["K: [B,S,16,192]"] --> QK
QK --> Scale["× softmax_scale"]
Scale --> Indexer["+ Indexer Mask"]
Indexer --> CausalMask["+ Causal Mask"]
CausalMask --> Softmax["softmax"]
Softmax --> AV["@ V"]
V["V: [B,S,16,128]"] --> AV
AV --> Output["[B,S,16,128]"]
end
源码:
1if mask is not None: # Prefill
2 q = torch.cat([q_nope, q_pe], dim=-1) # [B,S,16,192]
3
4 # 展开 KV
5 kv = self.wkv_b(kv) # [B,S,16*256]
6 kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
7 k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
8 k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
9
10 # 注意力计算
11 scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
12
13 # Indexer 稀疏化
14 topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
15 index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device)
16 index_mask = index_mask.scatter_(-1, topk_indices, 0) # 只保留 top-k
17 scores += index_mask.unsqueeze(2) + mask
18
19 scores = scores.softmax(dim=-1)
20 x = torch.einsum("bsht,bthd->bshd", scores, v)
4.4.2 Decode 模式 (逐 token 生成)
条件: mask is None (seqlen == 1)
关键优化: 不展开 K,而是先用 Q_nope 与 W_kv_b 合并,再与缓存的 c_KV 计算。
flowchart TB
subgraph "Decode 模式"
Q_NOPE["q_nope: [B,1,16,128]"] --> EINSUM1["einsum(bshd,hdc→bshc)"]
WKV_B["wkv_b[:, :128, :]: [16,128,512]"] --> EINSUM1
EINSUM1 --> Q_NOPE_PROJ["[B,1,16,512]"]
Q_NOPE_PROJ --> EINSUM2["einsum(bshc,btc→bsht)"]
KV_CACHE["kv_cache: [B,T,512]"] --> EINSUM2
EINSUM2 --> Score1["score1"]
Q_PE["q_pe: [B,1,16,64]"] --> EINSUM3["einsum(bshr,btr→bsht)"]
PE_CACHE["pe_cache: [B,T,64]"] --> EINSUM3
EINSUM3 --> Score2["score2"]
Score1 --> Add["+ × scale"]
Score2 --> Add
Add --> IndexMask["+ Indexer Mask"]
IndexMask --> Softmax["softmax"]
Softmax --> EINSUM4["einsum(bsht,btc→bshc)"]
KV_CACHE --> EINSUM4
EINSUM4 --> Out1["[B,1,16,512]"]
Out1 --> EINSUM5["einsum(bshc,hdc→bshd)"]
WKV_B2["wkv_b[:, -128:, :]: [16,128,512]"] --> EINSUM5
EINSUM5 --> Output["[B,1,16,128]"]
end
源码:
1else: # Decode
2 # 延迟反量化 wkv_b (只做一次)
3 if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
4 self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
5 wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
6 wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) # [16, 256, 512]
7
8 # 关键优化: Q_nope 先与 W_kv_b 的 K 部分结合
9 q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
10 # [B,1,16,128] @ [16,128,512] → [B,1,16,512]
11
12 # 与缓存计算注意力分数
13 scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
14 torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
15
16 # Indexer 稀疏化
17 topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
18 index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device)
19 index_mask = index_mask.scatter_(-1, topk_indices, 0)
20 scores += index_mask.unsqueeze(2)
21
22 scores = scores.softmax(dim=-1)
23
24 # 计算输出
25 x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
26 x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
5. Indexer 稀疏注意力
5.0 通俗解释:精选"关键词"的搜索引擎
问题:注意力太"民主"了!
传统注意力机制会让每个 token 关注所有历史 token。对于 163840 长度的上下文:
1传统注意力计算量:
2- 每个新 token 要和 163840 个历史 token 计算相似度
3- 计算量 = 163840 × 头数 × 维度 ≈ 非常大!
但实际上,大部分历史 token 跟当前问题根本没关系!
Indexer 的核心思想:先"粗筛"再"精选"
想象你在图书馆找资料:
- 传统方法:把图书馆所有书都翻一遍 → 太慢!
- Indexer 方法:
- 先用"关键词"快速搜索(粗筛)
- 只拿出最相关的 2048 本书(精选)
- 仔细阅读这 2048 本(精细注意力)
flowchart LR
subgraph "Indexer 工作流程"
All["163840 个历史 token"] --> Coarse["粗筛:快速评分"]
Coarse --> TopK["选出 top-2048"]
TopK --> Fine["精细注意力"]
Fine --> Output["最终结果"]
end
为什么选 2048?
| 选择数量 | 效果 | 计算量 |
|---|---|---|
| 全部 (163840) | 最准确 | 100% |
| 2048 | 几乎不损失 | 1.25% |
| 512 | 有损失 | 0.3% |
实验发现 2048 是个甜蜜点:计算量降到 1.25%,效果几乎不变!
Indexer 的评分公式
1对于当前 token 的 query q 和历史 token 的 key k:
2
3index_score = Σ (ReLU(q·k) × weight)
4
5- q·k: 相似度分数
6- ReLU: 只保留正分数(负相关的直接忽略)
7- weight: 每个注意力头的权重
8- Σ: 对所有头求和
直觉理解:
- 正分数 = “这本书可能有用” → 保留
- 负/零分数 = “这本书肯定没用” → 忽略
- weight = “这个头的判断有多可靠”
Hadamard 变换:让量化更准确
Indexer 使用 FP8 量化来加速计算,但 FP8 对数值分布有要求。Hadamard 变换就是一个"搅拌"操作:
1原始向量: [1, 0, 0, 0, ...] ← 很多 0,不利于量化
2变换后: [0.1, 0.1, 0.1, 0.1, ...] ← 值更均匀,量化更准
flowchart LR
X["原始向量
[1,0,0,0]"] --> H["Hadamard
变换"]
H --> Y["均匀化向量
[0.5,0.5,0.5,0.5]"]
Y --> Q["FP8 量化"]
Q --> GEMM["快速矩阵乘"]
与 MLA 的配合
flowchart TB
subgraph "MLA 主路径"
X["输入 x"] --> MLA_Q["生成 Q"]
X --> MLA_KV["生成 KV Cache"]
end
subgraph "Indexer 辅助路径"
X --> IDX_K["生成 Index K"]
MLA_Q --> QR["qr (压缩的Q)"]
QR --> IDX_Q["生成 Index Q"]
IDX_Q --> Score["计算 Index Score"]
IDX_K --> Score
Score --> TopK["选出 top-2048"]
end
TopK --> Mask["生成稀疏 Mask"]
MLA_KV --> Attn["稀疏注意力"]
Mask --> Attn
Attn --> Output["输出"]
一句话总结:Indexer 就是注意力机制的"搜索引擎",先快速筛选出最相关的 2048 个 token,再做精细注意力。
5.1 Indexer 类结构
位置: model.py:435-487
作用: 从所有历史 token 中选择最相关的 top-k (2048) 个进行精细注意力。
1class Indexer(torch.nn.Module):
2 def __init__(self, args: ModelArgs):
3 super().__init__()
4 self.dim = args.dim # 7168
5 self.n_heads = args.index_n_heads # 64
6 self.n_local_heads = 64 // world_size # 8 (每 GPU)
7 self.head_dim = args.index_head_dim # 128
8 self.rope_head_dim = args.qk_rope_head_dim # 64
9 self.index_topk = args.index_topk # 2048
10 self.q_lora_rank = args.q_lora_rank # 1536
11
12 # Indexer 自己的投影层
13 self.wq_b = Linear(1536, 64*128) # Q 投影 (使用 MLA 的 qr)
14 self.wk = Linear(7168, 128) # K 投影 (从原始 x)
15 self.k_norm = LayerNorm(128) # K 归一化
16 self.weights_proj = Linear(7168, 64, dtype=torch.float32) # 权重投影
17
18 # Indexer 的 K Cache (FP8)
19 self.register_buffer("k_cache", torch.zeros(8, 163840, 128, dtype=torch.float8_e4m3fn))
20 self.register_buffer("k_scale_cache", torch.zeros(8, 163840, 1)) # 128/128=1 个 scale
5.2 前向传播流程
flowchart TB
subgraph "Indexer Forward"
X["x: [B,S,7168]"] --> WK["wk(x)"]
WK --> K_Norm["k_norm"]
K_Norm --> K["k: [B,S,128]"]
K --> K_Split["split(64, 64)"]
K_Split --> K_PE["k_pe: 64"]
K_Split --> K_NOPE["k_nope: 64"]
K_PE --> RoPE_K["RoPE (non-interleaved)"]
RoPE_K --> K_Cat["concat"]
K_NOPE --> K_Cat
K_Cat --> Hadamard_K["Hadamard 变换"]
Hadamard_K --> Quant_K["FP8 量化"]
Quant_K --> K_Cache["存入 k_cache"]
QR["qr: [B,S,1536]
(from MLA)"] --> WQ_B["wq_b(qr)"]
WQ_B --> Q["q: [B,S,64,128]"]
Q --> Q_Split["split(64, 64)"]
Q_Split --> Q_PE["q_pe: 64"]
Q_Split --> Q_NOPE["q_nope: 64"]
Q_PE --> RoPE_Q["RoPE (non-interleaved)"]
RoPE_Q --> Q_Cat["concat"]
Q_NOPE --> Q_Cat
Q_Cat --> Hadamard_Q["Hadamard 变换"]
Hadamard_Q --> Quant_Q["FP8 量化"]
Quant_Q --> FP8_Index["fp8_index Kernel"]
K_Cache --> FP8_Index
X --> Weights_Proj["weights_proj(x)"]
Weights_Proj --> Weights["weights: [B,S,64]"]
Weights --> FP8_Index
FP8_Index --> Scores["index_scores: [B,S,T]"]
Scores --> Mask["+ causal mask"]
Mask --> TopK["top-k(2048)"]
TopK --> Indices["topk_indices: [B,S,2048]"]
end
源码逐行解析:
1def forward(self, x, qr, start_pos, freqs_cis, mask):
2 bsz, seqlen, _ = x.size()
3 end_pos = start_pos + seqlen
4
5 # === Q 计算 ===
6 q = self.wq_b(qr) # [B,S,64*128]
7 q = q.view(bsz, seqlen, self.n_heads, self.head_dim) # [B,S,64,128]
8 q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
9 # 注意: Indexer 的 RoPE 是非交错的
10 q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
11 q = torch.cat([q_pe, q_nope], dim=-1)
12
13 # === K 计算 ===
14 k = self.wk(x) # [B,S,128]
15 k = self.k_norm(k) # LayerNorm
16 k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
17 k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
18 k = torch.cat([k_pe, k_nope], dim=-1)
19
20 # === Hadamard 变换 ===
21 q = rotate_activation(q) # 随机正交变换
22 k = rotate_activation(k)
23
24 # === FP8 量化 ===
25 q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
26 k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
27
28 # === 更新 K Cache ===
29 self.k_cache[:bsz, start_pos:end_pos] = k_fp8
30 self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
31
32 # === 计算 Index Score ===
33 weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
34 weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
35
36 # FP8 矩阵乘 + ReLU + 求和
37 index_score = fp8_index(q_fp8.contiguous(), weights,
38 self.k_cache[:bsz, :end_pos].contiguous(),
39 self.k_scale_cache[:bsz, :end_pos].contiguous())
40
41 if mask is not None:
42 index_score += mask # 加上因果 mask
43
44 # === Top-K 选择 ===
45 topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
46
47 # 同步验证 (确保所有 GPU 选择相同的 token)
48 topk_indices_ = topk_indices.clone()
49 dist.broadcast(topk_indices_, src=0)
50 assert torch.all(topk_indices == topk_indices_)
51
52 return topk_indices
5.3 Hadamard 变换
位置: model.py:428-432
作用: 对 Q/K 进行随机正交变换,使它们的分布更均匀,有利于 FP8 量化。
1def rotate_activation(x: torch.Tensor) -> torch.Tensor:
2 assert x.dtype == torch.bfloat16
3 from fast_hadamard_transform import hadamard_transform
4 hidden_size = x.size(-1)
5 return hadamard_transform(x, scale=hidden_size ** -0.5)
Hadamard 变换:
$$H_n = \frac{1}{\sqrt{n}} \begin{pmatrix} H_{n/2} & H_{n/2} \ H_{n/2} & -H_{n/2} \end{pmatrix}$$
特点:
- 正交变换,保持向量范数
- 快速算法 O(n log n)
- 使激活值分布更均匀
6. 前馈网络
6.1 MLP 结构
位置: model.py:611-643
flowchart LR
X["x: [B,S,D]"] --> W1["w1 (gate_proj)
ColumnParallel"]
X --> W3["w3 (up_proj)
ColumnParallel"]
W1 --> SiLU["SiLU"]
SiLU --> Mul["×"]
W3 --> Mul
Mul --> W2["w2 (down_proj)
RowParallel"]
W2 --> Y["y: [B,S,D]"]
源码:
1class MLP(nn.Module):
2 def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
3 super().__init__()
4 self.w1 = ColumnParallelLinear(dim, inter_dim) # gate_proj
5 self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output) # down_proj
6 self.w3 = ColumnParallelLinear(dim, inter_dim) # up_proj
7
8 def forward(self, x: torch.Tensor) -> torch.Tensor:
9 # SwiGLU: SiLU(W1·x) × W3·x
10 return self.w2(
11 (F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x)
12 )
6.2 SwiGLU 激活
公式:
$$\text{SwiGLU}(x) = \text{SiLU}(W_1 \cdot x) \odot (W_3 \cdot x)$$
$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
为什么用 SwiGLU?
- 比 ReLU/GELU 更平滑
- 门控机制提供更强的表达能力
- 在 LLM 中效果更好
7. MoE 混合专家系统
7.0 通俗解释:256 个专家的"会诊"系统
问题:一个模型难以精通所有领域
想象你去医院看病:
- 感冒找内科医生
- 骨折找骨科医生
- 心脏问题找心内科医生
每个医生都是某个领域的"专家"。如果只有一个全科医生,他很难在所有领域都做到顶尖。
MoE 的核心思想:让专家"会诊"
DeepSeek-V3.2 有 256 个专家 + 1 个共享专家:
1256 个路由专家:每个专家擅长某类问题
2 - 专家 0-31: 可能擅长数学
3 - 专家 32-63: 可能擅长代码
4 - 专家 64-95: 可能擅长语言理解
5 - ... (自动学习的,无需人工指定)
6
71 个共享专家:处理通用信息,所有 token 都会经过
关键点:每个 token 只激活 8 个专家,不是全部 256 个!
路由决策流程
flowchart TB
subgraph "分组路由策略"
Token["输入 token"] --> Gate["计算 256 个专家的分数"]
Gate --> Group["分成 8 组,每组 32 个专家"]
Group --> G1["组1: 专家 0-31"]
Group --> G2["组2: 专家 32-63"]
Group --> G8["组8: 专家 224-255"]
G1 --> GS1["组1 得分"]
G2 --> GS2["组2 得分"]
G8 --> GS8["组8 得分"]
GS1 --> Top4["选出得分最高的 4 组"]
GS2 --> Top4
GS8 --> Top4
Top4 --> Top8["从这 4 组中选 8 个专家"]
end
为什么分组?
- 直接从 256 个里选 8 个 → 可能全挤在某几个专家
- 分组后选 → 保证专家的多样性,负载更均衡
具体例子:假设输入是一道数学编程题
1Step 1: 计算 256 个专家的分数
2 专家 5 (数学): 0.95
3 专家 40 (代码): 0.92
4 专家 67 (逻辑): 0.88
5 专家 102 (推理): 0.85
6 ... (其他专家分数较低)
7
8Step 2: 分组计算组得分 (每组取 top-2 求和)
9 组1 (0-31): 专家5=0.95 + 专家12=0.45 = 1.40
10 组2 (32-63): 专家40=0.92 + 专家55=0.50 = 1.42
11 组3 (64-95): 专家67=0.88 + 专家80=0.60 = 1.48
12 组4 (96-127): 专家102=0.85 + 专家110=0.55 = 1.40
13 ... (其他组得分较低)
14
15Step 3: 选出 top-4 组
16 ✓ 组3 (1.48)
17 ✓ 组2 (1.42)
18 ✓ 组1 (1.40)
19 ✓ 组4 (1.40)
20
21Step 4: 从这 4 组的 128 个专家中选 top-8
22 ✓ 专家 5, 40, 67, 102, 12, 55, 80, 110
23
24Step 5: 归一化权重并缩放 ×2.5
25 权重 = [0.18, 0.17, 0.16, 0.15, 0.10, 0.09, 0.08, 0.07]
专家计算流程
flowchart LR
subgraph "MoE Forward"
X["输入 x"] --> Gate["Gate: 选 8 个专家"]
Gate --> W["权重: [w1...w8]"]
Gate --> Idx["专家索引: [e1...e8]"]
X --> E1["专家 e1"]
X --> E2["专家 e2"]
X --> E8["专家 e8"]
E1 --> Mul1["× w1"]
E2 --> Mul2["× w2"]
E8 --> Mul8["× w8"]
Mul1 --> Sum["Σ 求和"]
Mul2 --> Sum
Mul8 --> Sum
X --> Shared["共享专家"]
Shared --> Sum
Sum --> Y["输出 y"]
end
共享专家 vs 路由专家
| 类型 | 数量 | 激活条件 | 作用 |
|---|---|---|---|
| 路由专家 | 256 | 被 Gate 选中才激活 | 处理特定领域 |
| 共享专家 | 1 | 所有 token 都激活 | 处理通用信息 |
类比:
- 路由专家 = 各科室的专家医生(按需会诊)
- 共享专家 = 全科医生(每个病人都先经过)
显存与计算的巧妙平衡
1参数量计算:
2- 如果用一个大 MLP: 7168 × 18432 × 3 ≈ 4 亿参数
3- 实际用 MoE:
4 - 256 个小专家: 256 × (7168 × 2048 × 3) ≈ 113 亿参数
5 - 每次只激活 8 个: 8 × ... ≈ 3.5 亿参数激活
6
7结果:参数量增加 28 倍,但计算量几乎不变!
一句话总结:MoE 就是让 256 个专家"分诊",每个 token 只找最相关的 8 个专家处理,既保持了大模型的能力,又控制了计算量。
7.1 Gate 路由器
位置: model.py:646-709
作用: 为每个 token 选择 top-k 个专家。
flowchart TB
subgraph "Gate Forward"
X["x: [B×S, 7168]"] --> Linear["x @ weight.T"]
Linear --> Scores["scores: [B×S, 256]"]
Scores --> Sigmoid["sigmoid"]
Sigmoid --> OrigScores["original_scores"]
Scores --> AddBias["+ bias"]
AddBias --> GroupView["view: [B×S, 8, 32]"]
GroupView --> GroupTopK["每组 top-2 求和"]
GroupTopK --> GroupScores["group_scores: [B×S, 8]"]
GroupScores --> SelectGroups["top-4 组"]
SelectGroups --> Mask["mask 其他组"]
Mask --> Flatten["flatten: [B×S, 256]"]
Flatten --> TopK8["top-8 专家"]
TopK8 --> Indices["indices: [B×S, 8]"]
OrigScores --> Gather["gather"]
Indices --> Gather
Gather --> Weights["weights"]
Weights --> Normalize["归一化"]
Normalize --> Scale["× 2.5"]
Scale --> FinalWeights["final_weights: [B×S, 8]"]
end
源码逐行解析:
1class Gate(nn.Module):
2 def __init__(self, args: ModelArgs):
3 super().__init__()
4 self.dim = args.dim # 7168
5 self.topk = args.n_activated_experts # 8
6 self.n_groups = args.n_expert_groups # 8
7 self.topk_groups = args.n_limited_groups # 4
8 self.score_func = args.score_func # "sigmoid"
9 self.route_scale = args.route_scale # 2.5
10
11 self.weight = nn.Parameter(torch.empty(256, 7168)) # 路由权重
12 # bias 只在 dim=7168 时启用 (即完整模型)
13 self.bias = nn.Parameter(torch.empty(256, dtype=torch.float32)) if self.dim == 7168 else None
14
15 def forward(self, x):
16 # 1. 计算原始分数
17 scores = linear(x.float(), self.weight.float()) # [B*S, 256]
18
19 # 2. Sigmoid 激活
20 scores = scores.sigmoid()
21 original_scores = scores # 保存用于最终权重
22
23 # 3. 添加 bias (用于调整专家偏好)
24 if self.bias is not None:
25 scores = scores + self.bias
26
27 # 4. 分组路由 (256 专家分成 8 组,每组 32 个)
28 if self.n_groups > 1:
29 scores = scores.view(x.size(0), self.n_groups, -1) # [B*S, 8, 32]
30
31 # 每组取 top-2 专家分数之和作为组分数
32 group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) # [B*S, 8]
33
34 # 选择 top-4 组
35 indices = group_scores.topk(self.topk_groups, dim=-1)[1] # [B*S, 4]
36
37 # mask 掉未选中的组
38 mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool)
39 mask = mask.scatter_(1, indices, False)
40 scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf"))
41 scores = scores.flatten(1) # [B*S, 256]
42
43 # 5. 从选中的组中选 top-8 专家
44 indices = scores.topk(self.topk, dim=-1)[1] # [B*S, 8]
45
46 # 6. 从原始分数中获取权重 (不是 mask 后的分数)
47 weights = original_scores.gather(1, indices) # [B*S, 8]
48
49 # 7. 归一化 (sigmoid 模式下)
50 weights /= weights.sum(dim=-1, keepdim=True)
51
52 # 8. 缩放
53 weights *= self.route_scale # × 2.5
54
55 return weights, indices
7.2 Expert 专家层
位置: model.py:712-744
结构: 与 MLP 相同,但使用普通 Linear (不并行)
1class Expert(nn.Module):
2 def __init__(self, dim: int, inter_dim: int):
3 super().__init__()
4 self.w1 = Linear(dim, inter_dim) # 7168 → 2048
5 self.w2 = Linear(inter_dim, dim) # 2048 → 7168
6 self.w3 = Linear(dim, inter_dim) # 7168 → 2048
7
8 def forward(self, x):
9 return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
7.3 MoE 整体结构
位置: model.py:747-804
flowchart TB
subgraph "MoE Forward"
X["x: [B,S,7168]"] --> Flatten["flatten: [B×S, 7168]"]
Flatten --> Gate["Gate"]
Gate --> Weights["weights: [B×S, 8]"]
Gate --> Indices["indices: [B×S, 8]"]
Flatten --> Shared["shared_experts
(MLP)"]
Shared --> Y_Shared["y_shared"]
subgraph "并行专家计算"
Indices --> Loop["for i in local_experts"]
Loop --> Find["找到选中专家 i 的 token"]
Find --> Expert_i["Expert_i(x[idx])"]
Expert_i --> Weighted["× weights"]
Weighted --> Accumulate["累加到 y"]
end
Y_Shared --> Add["y += y_shared"]
Accumulate --> Add
Add --> AllReduce["all_reduce"]
AllReduce --> Output["y: [B,S,7168]"]
end
源码逐行解析:
1class MoE(nn.Module):
2 def __init__(self, args: ModelArgs):
3 super().__init__()
4 self.dim = args.dim
5 self.n_routed_experts = args.n_routed_experts # 256
6 self.n_local_experts = args.n_routed_experts // world_size # 32 (每 GPU)
7 self.n_activated_experts = args.n_activated_experts # 8
8
9 # 当前 GPU 负责的专家范围
10 self.experts_start_idx = rank * self.n_local_experts
11 self.experts_end_idx = self.experts_start_idx + self.n_local_experts
12
13 self.gate = Gate(args)
14
15 # 只创建本 GPU 负责的专家
16 self.experts = nn.ModuleList([
17 Expert(args.dim, args.moe_inter_dim)
18 if self.experts_start_idx <= i < self.experts_end_idx
19 else None
20 for i in range(self.n_routed_experts)
21 ])
22
23 # 共享专家 (所有 token 都经过)
24 self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim,
25 reduce_output=False)
26
27 def forward(self, x):
28 shape = x.size()
29 x = x.view(-1, self.dim) # [B*S, 7168]
30
31 # 路由
32 weights, indices = self.gate(x) # [B*S, 8], [B*S, 8]
33
34 y = torch.zeros_like(x, dtype=torch.float32)
35
36 # 统计每个专家被选中的次数
37 counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
38
39 # 只计算本 GPU 负责的专家
40 for i in range(self.experts_start_idx, self.experts_end_idx):
41 if counts[i] == 0:
42 continue
43 expert = self.experts[i]
44 # 找到选中专家 i 的 (token_idx, top_idx)
45 idx, top = torch.where(indices == i)
46 # 计算并加权累加
47 y[idx] += expert(x[idx]) * weights[idx, top, None]
48
49 # 加上共享专家输出
50 y += self.shared_experts(x)
51
52 # 跨 GPU 合并
53 if world_size > 1:
54 dist.all_reduce(y)
55
56 return y.type_as(x).view(shape)
8. Block 与 Transformer
8.0 通俗解释:61 层"加工流水线"
Transformer 就是一条流水线
想象一个工厂的生产线:
- 原材料(token)从一端进入
- 经过 61 道工序(Block)加工
- 最后输出成品(预测的下一个词)
每道工序都做两件事:
- MLA 注意力:让这个位置"看看"其他位置的信息
- FFN/MoE:对信息进行"深加工"
flowchart LR
subgraph "Transformer 流水线"
Input["输入 tokens"] --> Embed["词嵌入"]
Embed --> B0["Block 0
(Dense)"]
B0 --> B1["Block 1
(Dense)"]
B1 --> B2["Block 2
(Dense)"]
B2 --> B3["Block 3
(MoE)"]
B3 --> Dots["..."]
Dots --> B60["Block 60
(MoE)"]
B60 --> Norm["最终归一化"]
Norm --> Head["输出投影"]
Head --> Output["预测 logits"]
end
为什么前 3 层用 Dense,后面用 MoE?
| 层级 | 类型 | 原因 |
|---|---|---|
| 0-2 层 | Dense MLP | 初期需要建立基础表示,Dense 更稳定 |
| 3-60 层 | MoE | 后期需要专业化处理,MoE 更高效 |
类比:
- 前 3 层 = 通识教育(所有学生学一样的基础课)
- 后 58 层 = 专业教育(不同学生选不同的专业课)
残差连接:信息的"高速公路"
每个 Block 都有残差连接,让信息可以"跳过"某些层直接传递:
1传统方式:x → Block → y
2残差方式:x → Block → y + x (输出 = 处理结果 + 原始输入)
为什么需要残差?
- 防止梯度消失(61 层太深了!)
- 让模型可以学习"不改变"(如果 Block 输出 0,结果就是原始输入)
- 让浅层特征可以直接传到深层
flowchart LR
subgraph "残差连接"
X["输入 x"] --> Block["Block 处理"]
Block --> Add["+"]
X --> Add
Add --> Y["输出 y = x + Block(x)"]
end
DeepSeek 的"融合残差"优化
传统残差需要两次加法(注意力后一次,FFN 后一次)。DeepSeek 把残差融合到 RMSNorm 里:
1# 传统方式 (需要存储中间结果)
2x = x + attention(norm(x)) # 第一次加法
3x = x + ffn(norm(x)) # 第二次加法
4
5# 融合方式 (节省显存)
6x, residual = norm(x, residual) # 归一化同时处理残差
7x = attention(x)
8x, residual = norm(x, residual) # 归一化同时处理残差
9x = ffn(x)
完整的前向传播流程
sequenceDiagram
participant T as Tokens
participant E as Embedding
participant B as Block (×61)
participant N as Final Norm
participant H as Head
T->>E: [B, S] 词 ID
E->>B: [B, S, 7168] 向量
loop 61 次
B->>B: RMSNorm + MLA
B->>B: RMSNorm + FFN/MoE
end
B->>N: [B, S, 7168]
N->>H: [B, 7168] (只取最后一个位置)
H->>T: [B, 129280] logits
一句话总结:Transformer 是 61 层的加工流水线,每层先"看上下文"(MLA),再"深加工"(FFN/MoE),通过残差连接保证信息流通。
8.1 Block 结构
位置: model.py:807-851
flowchart TB
subgraph "Block Forward"
X["x"] --> CheckRes{"residual is None?"}
Res["residual"] --> CheckRes
CheckRes -->|Yes| Init["x, residual = attn_norm(x), x"]
CheckRes -->|No| Fused["x, residual = attn_norm(x, residual)"]
Init --> MLA["MLA(x)"]
Fused --> MLA
MLA --> X2["x (attention output)"]
X2 --> FFN_Norm["x, residual = ffn_norm(x, residual)"]
Res2["residual"] --> FFN_Norm
FFN_Norm --> FFN["MLP / MoE"]
FFN --> X3["x (ffn output)"]
FFN_Norm --> Res3["residual (updated)"]
X3 --> Output["return (x, residual)"]
Res3 --> Output
end
源码:
1class Block(nn.Module):
2 def __init__(self, layer_id: int, args: ModelArgs):
3 super().__init__()
4 self.attn = MLA(args)
5 # 前 3 层用 Dense MLP,之后用 MoE
6 self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
7 self.attn_norm = RMSNorm(args.dim)
8 self.ffn_norm = RMSNorm(args.dim)
9
10 def forward(self, x, residual, start_pos, freqs_cis, mask):
11 # 注意力前归一化 + 残差处理
12 if residual is None:
13 x, residual = self.attn_norm(x), x
14 else:
15 x, residual = self.attn_norm(x, residual)
16
17 # 注意力
18 x = self.attn(x, start_pos, freqs_cis, mask)
19
20 # FFN 前归一化 + 残差处理
21 x, residual = self.ffn_norm(x, residual)
22
23 # FFN
24 x = self.ffn(x)
25
26 return x, residual
8.2 Transformer 主类
位置: model.py:854-913
flowchart TB
subgraph "Transformer Forward"
Tokens["tokens: [B, S]"] --> Embed["ParallelEmbedding"]
Embed --> H["h: [B, S, 7168]"]
H --> Init["residual = None"]
Init --> FreqsCis["freqs_cis = self.freqs_cis[start:end]"]
FreqsCis --> Mask["mask = triu(-inf) if S>1 else None"]
Mask --> Loop["for layer in self.layers"]
Loop --> Block["h, residual = layer(h, residual, ...)"]
Block --> Loop
Block --> FinalNorm["h, _ = norm(h, residual)"]
FinalNorm --> Head["head(h[:, -1])"]
Head --> Logits["logits: [B, vocab_size/8]"]
Logits --> AllGather["all_gather"]
AllGather --> FullLogits["logits: [B, vocab_size]"]
end
源码:
1class Transformer(nn.Module):
2 def __init__(self, args: ModelArgs):
3 global world_size, rank
4 world_size = dist.get_world_size() if dist.is_initialized() else 1
5 rank = dist.get_rank() if dist.is_initialized() else 0
6
7 # 设置全局精度
8 Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
9 Linear.scale_fmt = args.scale_fmt
10
11 super().__init__()
12 self.max_seq_len = args.max_seq_len
13 self.embed = ParallelEmbedding(args.vocab_size, args.dim)
14
15 # 61 层 Block
16 self.layers = torch.nn.ModuleList()
17 for layer_id in range(args.n_layers):
18 self.layers.append(Block(layer_id, args))
19
20 self.norm = RMSNorm(args.dim)
21 # lm_head 用 FP32 以获得更精确的 logits
22 self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
23
24 # 预计算位置编码
25 self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
26
27 @torch.inference_mode()
28 def forward(self, tokens: torch.Tensor, start_pos: int = 0):
29 seqlen = tokens.size(1)
30
31 # 获取当前位置的频率
32 freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
33
34 # 因果 mask (只在 prefill 时使用)
35 mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
36
37 # 嵌入
38 h, residual = self.embed(tokens), None
39
40 # 61 层 Block
41 for layer in self.layers:
42 h, residual = layer(h, residual, start_pos, freqs_cis, mask)
43
44 # 最终归一化
45 h, _ = self.norm(h, residual)
46
47 # 输出投影 (只取最后一个 token)
48 logits = self.head(h[:, -1].float())
49
50 # 合并所有 GPU 的 logits
51 if world_size > 1:
52 all_logits = [torch.empty_like(logits) for _ in range(world_size)]
53 dist.all_gather(all_logits, logits)
54 logits = torch.cat(all_logits, dim=-1)
55
56 return logits
8.3 残差连接策略
DeepSeek-V3.2 使用"融合残差"策略:
flowchart LR
subgraph "传统残差"
X1[x] --> Norm1[Norm]
Norm1 --> Attn1[Attention]
Attn1 --> Add1["+"]
X1 --> Add1
Add1 --> Norm2[Norm]
Norm2 --> FFN1[FFN]
FFN1 --> Add2["+"]
Add1 --> Add2
end
subgraph "融合残差 (节省显存)"
X2[x] --> FusedNorm1["Norm + 加残差"]
Res[residual] --> FusedNorm1
FusedNorm1 --> Attn2[Attention]
FusedNorm1 -->|new residual| FusedNorm2["Norm + 加残差"]
Attn2 --> FusedNorm2
FusedNorm2 --> FFN2[FFN]
FusedNorm2 -->|new residual| Out[output]
end
9. FP8 量化与 CUDA Kernel
位置: kernel.py
实现说明:本仓库的 kernel 由
tilelang.jit定义(运行时/首次调用时编译),并非手写.cu文件;同时通过pass_configs关闭部分优化选项以保证兼容性与稳定性。
9.1 act_quant 激活量化
flowchart TB
subgraph "act_quant"
X["x: BF16 [M, N]"] --> Reshape["reshape: [M, N/128, 128]"]
Reshape --> AbsMax["absmax per block"]
AbsMax --> Scale["scale = absmax / 448"]
Scale --> Clamp["x_fp8 = clamp(x/scale, -448, 448)"]
Clamp --> Y["y: FP8 [M, N]"]
Scale --> S["scale: FP32 [M, N/128]"]
end
9.2 fp8_gemm FP8 矩阵乘法
flowchart TB
subgraph "fp8_gemm"
A["A: FP8 [M, K]"] --> GEMM["Tensor Core GEMM"]
B["B: FP8 [N, K]"] --> GEMM
GEMM --> Acc["累加器: FP32"]
SA["scale_A: [M, K/128]"] --> ScaleMul["scale_A × scale_B"]
SB["scale_B: [N/128, K/128]"] --> ScaleMul
ScaleMul --> Acc
Acc --> C["C: BF16 [M, N]"]
end
9.3 fp8_index:Indexer 用的 FP8 打分 Kernel
Indexer 会把 q/k 量化到 FP8,并调用 fp8_index 计算稀疏选择所需的打分矩阵(输出 FP32 分数,供后续 top-k 选择)。
10. 推理流程
位置: generate.py
10.1 完整推理时序图
sequenceDiagram
participant User as 用户
participant Main as main()
participant Gen as generate()
participant TF as Transformer
participant Sample as sample()
User->>Main: 启动推理
Main->>Main: 加载 ModelArgs 配置(如 inference/config_671B_v3.2.json)
Main->>TF: Transformer(args)
Main->>Main: load_model (safetensors)
Main->>Main: 加载 tokenizer
loop 交互循环
User->>Main: 输入 prompt(多卡时仅 rank0 读取)
Main->>Main: 多卡同步 prompt(dist.broadcast_object_list)
Main->>Main: apply_chat_template
Main->>Gen: generate(model, prompt_tokens, ...)
rect rgb(200, 230, 255)
Note over Gen,TF: Prefill 阶段
Gen->>TF: forward(tokens[:, 0:len], start_pos=0)
TF-->>Gen: logits
end
rect rgb(255, 230, 200)
Note over Gen,TF: Decode 阶段
loop until EOS
Gen->>Sample: sample(logits, temp)
Sample-->>Gen: next_token
Gen->>TF: forward(next_token, start_pos=cur)
TF-->>Gen: logits
end
end
Gen-->>Main: completion_tokens
Main-->>User: decoded text
end
10.2 sample 函数
1def sample(logits, temperature: float = 1.0):
2 logits = logits / max(temperature, 1e-5)
3 probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
4 # Gumbel-Max trick: 等价于 multinomial 但更高效
5 return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
11. 权重转换流程
位置: convert.py
11.1 转换流程图
flowchart TB
subgraph "HuggingFace 格式"
HF["163 个 safetensors 文件"]
end
subgraph "convert.py"
HF --> Load["加载每个文件"]
Load --> Skip["跳过 layer.61 (MTP)"]
Skip --> Rename["重命名 keys"]
Rename --> Shard["按 MP 分片"]
end
subgraph "输出"
Shard --> Out["model{0-7}-mp8.safetensors"]
end
11.2 Key 映射表
| HuggingFace 名称 | 转换后名称 | 分片 |
|---|---|---|
embed_tokens |
embed |
dim=0 |
input_layernorm |
attn_norm |
None |
post_attention_layernorm |
ffn_norm |
None |
q_proj |
wq |
dim=0 |
q_a_proj |
wq_a |
None |
q_a_layernorm |
q_norm |
None |
q_b_proj |
wq_b |
dim=0 |
kv_a_proj_with_mqa |
wkv_a |
None |
kv_a_layernorm |
kv_norm |
None |
kv_b_proj |
wkv_b |
dim=0 |
o_proj |
wo |
dim=1 |
gate |
gate |
None |
gate_proj |
w1 |
dim=0 |
up_proj |
w3 |
dim=0 |
down_proj |
w2 |
dim=1 |
norm |
norm |
None |
lm_head |
head |
dim=0 |
wk |
wk |
None |
k_norm |
k_norm |
None |
weights_proj |
weights_proj |
None |
额外重命名规则(实现中直接
str.replace):去掉前缀model.、self_attn → attn、mlp → ffn、weight_scale_inv → scale、e_score_correction_bias → bias。另外会跳过model.layers.61.*(MTP 层)。
附录 A: 张量维度速查表
| 组件 | 张量 | 维度 (DeepSeek-V3.2) |
|---|---|---|
| Embedding | weight | [129280/8, 7168] |
| MLA | wq_a.weight | [1536, 7168] |
| wq_b.weight | [128×192/8, 1536] | |
| wkv_a.weight | [576, 7168] | |
| wkv_b.weight | [128×256/8, 512] | |
| wo.weight | [7168, 128×128/8] | |
| kv_cache | [8, 163840, 512] | |
| pe_cache | [8, 163840, 64] | |
| Indexer | wq_b.weight | [64×128, 1536] |
| wk.weight | [128, 7168] | |
| k_cache | [8, 163840, 128] FP8 | |
| MLP | w1.weight | [18432/8, 7168] |
| w2.weight | [7168, 18432/8] | |
| w3.weight | [18432/8, 7168] | |
| Expert | w1.weight | [2048, 7168] |
| w2.weight | [7168, 2048] | |
| w3.weight | [2048, 7168] | |
| Gate | weight | [256, 7168] |
| Head | weight | [129280/8, 7168] |
附录 B: 关键公式汇总
MLA 注意力
$$ \begin{aligned} c_Q &= \text{RMSNorm}(W_{QA} \cdot x) & \text{压缩 Q} \ Q &= W_{QB} \cdot c_Q & \text{展开 Q} \ [c_{KV}, k_{pe}] &= W_{KVA} \cdot x & \text{压缩 KV} \ c_{KV} &= \text{RMSNorm}(c_{KV}) & \text{归一化} \ [k_{nope}, v] &= W_{KVB} \cdot c_{KV} & \text{展开 KV} \ \text{Attn} &= \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + M_{idx}\right) V \end{aligned} $$
MoE 路由
$$ \begin{aligned} s_i &= \sigma(W_{gate} \cdot x + b)i \ G_j &= \sum{i \in \text{group}_j} \text{top2}(s_i) \ \text{groups} &= \text{topk}(G, k=4) \ \text{experts} &= \text{topk}(s[\text{groups}], k=8) \ w &= \text{normalize}(s[\text{experts}]) \times 2.5 \end{aligned} $$
RMSNorm
$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \cdot \gamma$$
SwiGLU
$$\text{SwiGLU}(x) = \text{SiLU}(W_1 x) \odot W_3 x$$
附录 C: 运行命令
转换权重
1cd inference
2python convert.py \
3 --hf-ckpt-path /path/to/hf/checkpoint \
4 --save-path /path/to/converted \
5 --n-experts 256 \
6 --model-parallel 8
单卡推理
1cd inference
2python generate.py \
3 --ckpt-path /path/to/converted \
4 --config config_671B_v3.2.json \
5 --interactive \
6 --temperature 0.6
多卡推理
1cd inference
2torchrun --nproc_per_node=8 generate.py \
3 --ckpt-path /path/to/converted \
4 --config config_671B_v3.2.json \
5 --interactive \
6 --temperature 0.6