Deepseek V3.2 Architecture Guide

Posted by NoPanic on Thu, Dec 11, 2025

DeepSeek-V3.2 源码架构详解

本文档面向初学者,详细解析 DeepSeek-V3.2 模型的源码实现,帮助读者理解其核心架构和运行机制。

精细度:本文档解析到每一个类、每一个函数、每一行关键代码的粒度。


目录

  1. 整体架构概览
  2. 基础组件详解
  3. 位置编码详解
  4. MLA 注意力机制
  5. Indexer 稀疏注意力
  6. 前馈网络
  7. MoE 混合专家系统
  8. Block 与 Transformer
  9. FP8 量化与 CUDA Kernel
  10. 推理流程
  11. 权重转换流程

1. 整体架构概览

1.1 模型架构图

graph TB
    subgraph "DeepSeek-V3.2 Transformer"
        Input[输入 Tokens] --> Embed[ParallelEmbedding
词嵌入层] Embed --> Block1[Block 0
Dense FFN] Block1 --> Block2[Block 1
Dense FFN] Block2 --> Block3[Block 2
Dense FFN] Block3 --> BlockN[Block 3-60
MoE FFN] BlockN --> Norm[RMSNorm
最终归一化] Norm --> Head[ColumnParallelLinear
输出投影] Head --> Output[输出 Logits] end subgraph "每个 Block 内部结构" direction LR X[输入 x] --> AttnNorm[RMSNorm] AttnNorm --> MLA[MLA 注意力] MLA --> FFNNorm[RMSNorm] FFNNorm --> FFN[MLP/MoE] FFN --> Y[输出] end

1.2 文件结构与代码行数

 1inference/
 2├── model.py      # 923 行 - 核心模型定义
 3├── generate.py   # 187 行 - 推理生成逻辑
 4├── kernel.py     # 275 行 - Kernel 实现(TileLang JIT:act_quant / fp8_gemm / fp8_index)
 5├── convert.py    # 101 行 - 权重转换脚本(HF → 本仓库推理格式)
 6├── config_671B_v3.2.json  # 671B 推理 demo 的 ModelArgs 配置(与 generate.py 匹配)
 7├── requirements.txt       # 推理依赖(含 tilelang / fast_hadamard_transform)
 8└── README.md              # 推理 demo 使用说明
 9encoding/
10└── encoding_dsv32.py  # Chat/Tool 调用的 prompt 编码(与推理代码独立)

说明:本文档中若写 model.py/generate.py/kernel.py/convert.py,默认指 inference/ 目录下同名文件。

1.3 类依赖关系图

classDiagram
    Transformer --> ParallelEmbedding
    Transformer --> Block
    Transformer --> RMSNorm
    Transformer --> ColumnParallelLinear

    Block --> MLA
    Block --> MLP
    Block --> MoE
    Block --> RMSNorm

    MLA --> Linear
    MLA --> ColumnParallelLinear
    MLA --> RowParallelLinear
    MLA --> RMSNorm
    MLA --> Indexer

    Indexer --> Linear
    Indexer --> LayerNorm

    MoE --> Gate
    MoE --> Expert
    MoE --> MLP

    Expert --> Linear

    MLP --> ColumnParallelLinear
    MLP --> RowParallelLinear

    ColumnParallelLinear --|> Linear
    RowParallelLinear --|> Linear

2. 基础组件详解

2.1 ModelArgs 配置类

位置: model.py:17-90

 1@dataclass
 2class ModelArgs:
 3    # 基础配置
 4    max_batch_size: int = 8           # 最大批次大小
 5    max_seq_len: int = 4096 * 4       # 最大序列长度 (16K,实际 V3.2 是 163K)
 6    dtype: Literal["bf16", "fp8"] = "bf16"  # 计算精度
 7    scale_fmt: Optional[str] = None   # FP8 激活量化 scale 的存储格式(如 ue8m0)
 8    vocab_size: int = 102400          # 词表大小 (实际 V3.2 是 129280)
 9    dim: int = 2048                   # 隐藏层维度 (实际 V3.2 是 7168)
10    inter_dim: int = 10944            # Dense MLP 的中间维度(实际 V3.2 是 18432)
11    moe_inter_dim: int = 1408         # MoE expert 的中间维度(实际 V3.2 是 2048)
12    n_layers: int = 27                # 层数 (实际 V3.2 是 61)
13    n_dense_layers: int = 1           # Dense 层数 (实际 V3.2 是 3)
14    n_heads: int = 16                 # 注意力头数 (实际 V3.2 是 128)
15
16    # MoE 配置
17    n_routed_experts: int = 64        # 路由专家数 (实际 V3.2 是 256)
18    n_shared_experts: int = 2         # 共享专家数 (实际 V3.2 是 1)
19    n_activated_experts: int = 6      # 激活专家数 (实际 V3.2 是 8)
20    n_expert_groups: int = 1          # 专家分组数 (实际 V3.2 是 8)
21    n_limited_groups: int = 1         # 选择的组数 (实际 V3.2 是 4)
22    score_func: str = "softmax"       # 评分函数 (实际 V3.2 是 sigmoid)
23    route_scale: float = 1.           # 路由缩放 (实际 V3.2 是 2.5)
24
25    # MLA 配置
26    q_lora_rank: int = 0              # Q 低秩维度 (实际 V3.2 是 1536)
27    kv_lora_rank: int = 512           # KV 低秩维度
28    qk_nope_head_dim: int = 128       # 无位置编码的 QK 维度
29    qk_rope_head_dim: int = 64        # 有位置编码的 QK 维度
30    v_head_dim: int = 128             # V 头维度
31
32    # YaRN 配置
33    original_seq_len: int = 4096      # 原始训练长度
34    rope_theta: float = 10000.0       # RoPE 基础频率
35    rope_factor: float = 40           # 长度外推因子
36    beta_fast: int = 32               # 快速衰减参数
37    beta_slow: int = 1                # 慢速衰减参数
38    mscale: float = 1.                # YaRN 额外注意力缩放(影响 softmax_scale)
39
40    # Indexer 配置
41    index_n_heads: int = 64           # Indexer 头数
42    index_head_dim: int = 128         # Indexer 头维度
43    index_topk: int = 2048            # 选择的 token 数

配置流程图:

flowchart LR
    subgraph "配置加载"
        JSON[inference/config_671B_v3.2.json] --> Parse[json.load]
        Parse --> Args[ModelArgs**kwargs]
        Args --> Model[Transformer]
    end

注意:仓库根目录的 config.json 是 HuggingFace Transformers 的模型配置(字段名如 hidden_size/num_hidden_layers/...),不能直接喂给 ModelArgs(**json.load(...))。本推理 demo 使用的是 inference/config_671B_v3.2.json(字段名与 ModelArgs 一一对应)。


2.2 ParallelEmbedding 并行嵌入层

位置: model.py:92-131

作用: 将词表按行切分到多个 GPU,实现词嵌入的张量并行。

2.2.1 初始化流程

flowchart TB
    subgraph "ParallelEmbedding.__init__"
        VS[vocab_size=129280] --> Check{vocab_size % world_size == 0?}
        Check -->|Yes| Calc[part_vocab_size = 129280/8 = 16160]
        Calc --> Range["vocab_start_idx = rank × 16160
vocab_end_idx = start + 16160"] Range --> Weight["weight = Parameter(16160, 7168)"] end

源码逐行解析:

 1class ParallelEmbedding(nn.Module):
 2    def __init__(self, vocab_size: int, dim: int):
 3        super().__init__()
 4        self.vocab_size = vocab_size  # 完整词表大小: 129280
 5        self.dim = dim                # 嵌入维度: 7168
 6
 7        # 检查词表是否可被 GPU 数整除
 8        assert vocab_size % world_size == 0
 9
10        # 计算每个 GPU 负责的词表范围
11        self.part_vocab_size = vocab_size // world_size  # 129280/8=16160
12        self.vocab_start_idx = rank * self.part_vocab_size  # GPU0: 0, GPU1: 16160, ...
13        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
14
15        # 只存储本 GPU 负责的词表部分
16        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

2.2.2 关于 world_size 的说明

代码顶部的 world_size = 1 只是默认值,实际运行时会被动态修改:

 1# model.py:13-14 (默认值)
 2world_size = 1
 3rank = 0
 4
 5# model.py:873-875 (运行时修改)
 6class Transformer(nn.Module):
 7    def __init__(self, args: ModelArgs):
 8        global world_size, rank  # 修改全局变量
 9        world_size = dist.get_world_size() if dist.is_initialized() else 1
10        rank = dist.get_rank() if dist.is_initialized() else 0
运行场景 world_size part_vocab_size
单卡运行 1 129280 (完整词表)
8卡运行 (torchrun --nproc_per_node=8) 8 16160

2.2.3 前向传播通俗解释:分工合作查字典

想象词表就是一本 12 万词的大字典,太大了一个人查不过来,于是让 8 个 GPU 分工:

1GPU 0: 负责词 0 ~ 16159      (相当于 A-C 开头的词)
2GPU 1: 负责词 16160 ~ 32319  (相当于 D-F 开头的词)
3GPU 2: 负责词 32320 ~ 48479  (相当于 G-I 开头的词)
4...
5GPU 7: 负责词 113120 ~ 129279 (相当于 X-Z 开头的词)

具体例子:假设输入是 x = [100, 20000, 50000](3 个词的 ID)


Step 1: 每个 GPU 判断"这是不是我的活"

1mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
GPU 负责范围 词 100 词 20000 词 50000 mask 结果
GPU 0 0~16159 ✓ 我的 ✗ 不是 ✗ 不是 [False, True, True]
GPU 1 16160~32319 ✗ 不是 ✓ 我的 ✗ 不是 [True, False, True]
GPU 3 48480~64639 ✗ 不是 ✗ 不是 ✓ 我的 [True, True, False]

Step 2: 转换成"本地索引"

1x = x - self.vocab_start_idx  # 减去起始位置
2x[mask] = 0                   # 不是我的活就填 0(防止数组越界)
GPU 原始 x 减去起始偏移 不是我的填 0
GPU 0 (起始=0) [100, 20000, 50000] [100, 20000, 50000] [100, 0, 0]
GPU 1 (起始=16160) [100, 20000, 50000] [-16060, 3840, 33840] [0, 3840, 0]
GPU 3 (起始=48480) [100, 20000, 50000] [-48380, -28480, 1520] [0, 0, 1520]

Step 3: 各自查表

1y = F.embedding(x, self.weight)
GPU 查表结果
GPU 0 [词100的向量, 词0的向量(垃圾), 词0的向量(垃圾)]
GPU 1 [词0的向量(垃圾), 词20000的向量, 词0的向量(垃圾)]
GPU 3 [词0的向量(垃圾), 词0的向量(垃圾), 词50000的向量]

Step 4: 把"垃圾数据"清零

1y[mask] = 0  # 不是我负责的位置,结果清零
GPU 清零后的结果
GPU 0 [词100的向量, 0, 0]
GPU 1 [0, 词20000的向量, 0]
GPU 3 [0, 0, 词50000的向量]

Step 5: 汇总所有 GPU 的结果

1dist.all_reduce(y)  # 所有 GPU 的结果相加
1  [词100的向量,  0,            0           ]   ← GPU 0 的贡献
2+ [0,           词20000的向量, 0           ]   ← GPU 1 的贡献
3+ [0,           0,            词50000的向量]   ← GPU 3 的贡献
4+ ...                                          ← 其他 GPU 贡献 0
5─────────────────────────────────────────────────────────────────
6= [词100的向量, 词20000的向量, 词50000的向量]   ← 最终结果

流程图总结

flowchart TB
    subgraph "输入"
        X["x = [100, 20000, 50000]
3个词需要查询"] end subgraph "GPU 0 (负责词 0-16159)" M0["Step1: 判断
mask = [F, T, T]
只有100是我的"] L0["Step2: 本地索引
[100, 0, 0]"] E0["Step3: 查表
[emb_100, 垃圾, 垃圾]"] C0["Step4: 清零
[emb_100, 0, 0]"] end subgraph "GPU 1 (负责词 16160-32319)" M1["Step1: 判断
mask = [T, F, T]
只有20000是我的"] L1["Step2: 本地索引
[0, 3840, 0]"] E1["Step3: 查表
[垃圾, emb_20000, 垃圾]"] C1["Step4: 清零
[0, emb_20000, 0]"] end subgraph "GPU 3 (负责词 48480-64639)" M3["Step1: 判断
mask = [T, T, F]
只有50000是我的"] L3["Step2: 本地索引
[0, 0, 1520]"] E3["Step3: 查表
[垃圾, 垃圾, emb_50000]"] C3["Step4: 清零
[0, 0, emb_50000]"] end X --> M0 --> L0 --> E0 --> C0 X --> M1 --> L1 --> E1 --> C1 X --> M3 --> L3 --> E3 --> C3 C0 --> AR["Step5: all_reduce
所有结果相加"] C1 --> AR C3 --> AR AR --> Y["最终结果
y = [emb_100, emb_20000, emb_50000]"]

一句话总结:每个 GPU 只查自己负责的词,查到了就返回向量,查不到就返回 0,最后把所有 GPU 的结果加起来。

2.2.4 源码逐行解析

 1def forward(self, x: torch.Tensor) -> torch.Tensor:
 2    if world_size > 1:
 3        # 1. 创建 mask: 标记不属于本 GPU 的 token
 4        mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
 5
 6        # 2. 转换为本地索引 (减去起始偏移)
 7        x = x - self.vocab_start_idx
 8
 9        # 3. 不属于本 GPU 的 token 索引设为 0 (防止越界)
10        x[mask] = 0
11
12    # 4. 查表得到嵌入
13    y = F.embedding(x, self.weight)
14
15    if world_size > 1:
16        # 5. 不属于本 GPU 的结果清零
17        y[mask] = 0
18
19        # 6. 所有 GPU 结果求和 (all_reduce)
20        dist.all_reduce(y)
21
22    return y

2.2.5 显存对比

配置 词表显存
单卡完整词表 129280 × 7168 × 2B = 1.73 GB
8卡并行 (每卡) 16160 × 7168 × 2B = 216 MB

2.3 Linear 及并行变体

位置: model.py:134-269

2.3.1 通俗解释:分工合作算矩阵

想象你要计算一个超大的矩阵乘法 y = x × W,其中 W 是一个 7168 × 18432 的巨型矩阵(约 2.6 亿个数字)。单卡显存放不下,怎么办?

两种切分策略


策略一:列并行 (ColumnParallelLinear) - “每人算一部分结果”

1想象你要做一道数学题:计算 [1, 2, 3] × [[a, b, c, d],
2                                      [e, f, g, h],
3                                      [i, j, k, l]]
4
5结果是 [y1, y2, y3, y4]
6
7列并行就是让 2 个 GPU 分工:
8- GPU 0: 只算 [y1, y2]  (对应矩阵的前 2 列)
9- GPU 1: 只算 [y3, y4]  (对应矩阵的后 2 列)
flowchart LR
    subgraph "列并行:输入相同,输出切分"
        X["x: [B,S,7168]
完整输入"] --> GPU0["GPU0: W的前半列
[7168, 9216]"] X --> GPU1["GPU1: W的后半列
[7168, 9216]"] GPU0 --> Y0["y0: [B,S,9216]"] GPU1 --> Y1["y1: [B,S,9216]"] end

特点

  • ✅ 输入不需要通信(每个 GPU 都有完整的 x)
  • ✅ 输出直接使用(不需要 all_reduce)
  • ⚠️ 输出是"切片",后续层要知道怎么用

策略二:行并行 (RowParallelLinear) - “每人算一部分加数”

 1想象你要做一道数学题:计算 [1, 2, 3, 4] × [[a],
 2                                          [b],
 3                                          [c],
 4                                          [d]]
 5
 6结果是 y = 1×a + 2×b + 3×c + 4×d
 7
 8行并行就是让 2 个 GPU 分工:
 9- GPU 0: 计算 1×a + 2×b (输入的前半部分 × 权重的前半行)
10- GPU 1: 计算 3×c + 4×d (输入的后半部分 × 权重的后半行)
11- 最后:all_reduce 求和得到 y
flowchart LR
    subgraph "行并行:输入切分,输出求和"
        X0["x0: [B,S,3584]
输入前半"] --> GPU0["GPU0: W的前半行
[3584, 7168]"] X1["x1: [B,S,3584]
输入后半"] --> GPU1["GPU1: W的后半行
[3584, 7168]"] GPU0 --> Y0["y0: [B,S,7168]
(部分和)"] GPU1 --> Y1["y1: [B,S,7168]
(部分和)"] Y0 --> AR["all_reduce
求和"] Y1 --> AR AR --> Y["y: [B,S,7168]
完整结果"] end

特点

  • ⚠️ 输入必须是切片(通常来自上一层的列并行)
  • ⚠️ 输出需要 all_reduce 通信
  • ✅ 结果是完整的,可以直接用

组合使用的例子:MLP 层

1MLP 有 3 个矩阵:W1, W2, W3
2
3设计思路:
4- W1 和 W3 用列并行 → 输出是切片
5- W2 用行并行 → 输入正好是切片,输出是完整结果
6
7这样刚好"接龙",只需要在 W2 后做一次 all_reduce!
flowchart LR
    X["x: 完整"] --> W1["W1: 列并行"]
    X --> W3["W3: 列并行"]
    W1 --> SILU["SiLU"]
    SILU --> MUL["×"]
    W3 --> MUL
    MUL --> W2["W2: 行并行
(含 all_reduce)"] W2 --> Y["y: 完整"]

2.3.2 linear 函数详解

 1def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
 2           scale_fmt: Optional[str] = None) -> torch.Tensor:
 3    assert bias is None  # 不支持 bias
 4
 5    if weight.dtype != torch.float8_e4m3fn:
 6        # BF16 权重: 使用标准 F.linear
 7        return F.linear(x, weight)
 8    else:
 9        # FP8 权重: 量化输入后使用 FP8 GEMM
10        x, scale = act_quant(x, block_size, scale_fmt)  # 量化激活
11        return fp8_gemm(x, scale, weight, weight.scale)  # FP8 矩阵乘

通俗解释

  • BF16 模式:直接调用 PyTorch 的标准矩阵乘法
  • FP8 模式
    1. 先把输入 x 从 BF16 “压缩"成 FP8(节省显存和带宽)
    2. 用专门的 FP8 矩阵乘法 Kernel(Tensor Core 加速)
    3. 结果自动转回 BF16

FP8 vs BF16 计算流程:

flowchart TB
    subgraph "BF16 模式(简单直接)"
        X1[x: BF16] --> Linear1[F.linear]
        W1[weight: BF16] --> Linear1
        Linear1 --> Y1[y: BF16]
    end

    subgraph "FP8 模式(压缩后计算)"
        X2[x: BF16] --> Quant[act_quant
量化] Quant --> X_FP8[x: FP8
1字节/数] Quant --> X_Scale[x_scale: FP32
缩放因子] X_FP8 --> GEMM[fp8_gemm
FP8矩阵乘] X_Scale --> GEMM W2[weight: FP8] --> GEMM W_Scale[weight.scale: FP32] --> GEMM GEMM --> Y2[y: BF16] end

2.3.3 Linear 类

 1class Linear(nn.Module):
 2    dtype = torch.bfloat16      # 类级别默认精度
 3    scale_fmt: Optional[str] = None
 4
 5    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
 6        super().__init__()
 7        self.in_features = in_features
 8        self.out_features = out_features
 9
10        # 创建权重参数
11        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
12
13        # FP8 模式下需要额外的 scale 参数
14        if self.weight.element_size() == 1:  # FP8 每个元素 1 字节
15            scale_out = (out_features + block_size - 1) // block_size  # 向上取整
16            scale_in = (in_features + block_size - 1) // block_size
17            # scale 按 128×128 块存储
18            self.weight.scale = self.scale = nn.Parameter(
19                torch.empty(scale_out, scale_in, dtype=torch.float32)
20            )
21        else:
22            self.register_parameter("scale", None)

FP8 权重与 Scale 的关系:

flowchart TB
    subgraph "权重矩阵 (7168 × 1536)"
        W[weight: FP8
shape: 7168×1536] end subgraph "Scale 矩阵 (56 × 12)" S["scale: FP32
shape: ceil(7168/128) × ceil(1536/128)
= 56 × 12"] end subgraph "对应关系" B1["Block (0,0)
128×128"] --> S1["scale[0,0]"] B2["Block (0,1)
128×128"] --> S2["scale[0,1]"] BN["Block (55,11)
128×128"] --> SN["scale[55,11]"] end

2.3.3 ColumnParallelLinear (列并行)

作用: 将输出维度切分到多个 GPU。

flowchart LR
    subgraph "ColumnParallelLinear"
        X[x: B×S×D] --> GPU0["GPU0: W[0:N/8, :]"]
        X --> GPU1["GPU1: W[N/8:2N/8, :]"]
        X --> GPU7["GPU7: W[7N/8:N, :]"]
        GPU0 --> Y0["y0: B×S×N/8"]
        GPU1 --> Y1["y1: B×S×N/8"]
        GPU7 --> Y7["y7: B×S×N/8"]
    end
1class ColumnParallelLinear(Linear):
2    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
3        assert out_features % world_size == 0
4        self.part_out_features = out_features // world_size  # 输出维度切分
5        super().__init__(in_features, self.part_out_features, bias, dtype)
6
7    def forward(self, x: torch.Tensor) -> torch.Tensor:
8        # 直接计算,输出是完整输出的一部分
9        return linear(x, self.weight, self.bias, self.scale_fmt)

2.3.4 RowParallelLinear (行并行)

作用: 将输入维度切分到多个 GPU,计算后需要 all_reduce 合并。

flowchart LR
    subgraph "RowParallelLinear"
        X0["x0: B×S×D/8"] --> GPU0["GPU0: W[:, 0:D/8]"]
        X1["x1: B×S×D/8"] --> GPU1["GPU1: W[:, D/8:2D/8]"]
        X7["x7: B×S×D/8"] --> GPU7["GPU7: W[:, 7D/8:D]"]
        GPU0 --> Y0["y0: B×S×N"]
        GPU1 --> Y1["y1: B×S×N"]
        GPU7 --> Y7["y7: B×S×N"]
        Y0 --> AR[all_reduce]
        Y1 --> AR
        Y7 --> AR
        AR --> Y["y: B×S×N"]
    end
 1class RowParallelLinear(Linear):
 2    def __init__(self, in_features: int, out_features: int, bias: bool = False,
 3                 reduce_output=True, dtype=None):
 4        assert in_features % world_size == 0
 5        self.part_in_features = in_features // world_size  # 输入维度切分
 6        self.reduce_output = reduce_output
 7        super().__init__(self.part_in_features, out_features, bias, dtype)
 8
 9    def forward(self, x: torch.Tensor) -> torch.Tensor:
10        y = linear(x, self.weight, None, self.scale_fmt)
11
12        if self.reduce_output and world_size > 1:
13            y = y.float()
14            dist.all_reduce(y)  # 关键: 合并所有 GPU 的结果
15
16        if self.bias is not None:
17            y += self.bias
18
19        return y.type_as(x)

2.3.5 并行策略对比

类型 切分维度 通信操作 适用场景
ColumnParallel 输出 (dim=0) 无 (输出直接使用) 后接 RowParallel
RowParallel 输入 (dim=1) all_reduce 后接 ColumnParallel 或输出

2.4 RMSNorm 与 LayerNorm

位置: model.py:272-321

2.4.1 通俗解释:给向量"校准刻度”

为什么需要归一化?

想象你有一组测量数据:[0.001, 0.002, 0.003] 和另一组 [1000, 2000, 3000]。虽然它们的"相对关系"一样,但数值差了百万倍!

神经网络在训练/推理时,不同层输出的数值范围可能差异很大。如果不处理:

  • 有些层输出 0.001,有些输出 10000
  • 后面的计算会被大数值"淹没"
  • 梯度可能爆炸或消失

归一化的作用:把每个向量"校准"到统一的刻度,通常让数值在 -1 到 1 附近。


RMSNorm vs LayerNorm:两种"校准"方式

方法 LayerNorm RMSNorm
步骤 1. 减去均值 2. 除以标准差 只除以 RMS(均方根)
公式 $\frac{x - \mu}{\sigma}$ $\frac{x}{\text{RMS}(x)}$
参数 weight + bias 只有 weight
计算量 较多 较少
效果 更"中心化" 保持原始分布形状

通俗比喻

  • LayerNorm = “把考试成绩换算成标准分”(减均值、除标准差)
  • RMSNorm = “把成绩按总分折算成百分比”(只做缩放,不平移)

具体例子:假设输入向量 x = [3, 4]

 1RMSNorm 计算过程:
 2
 3Step 1: 计算平方
 4        x² = [9, 16]
 5
 6Step 2: 计算均方 (Mean of Squares)
 7        mean(x²) = (9 + 16) / 2 = 12.5
 8
 9Step 3: 计算均方根 (Root Mean Square)
10        RMS = √12.5 ≈ 3.54
11
12Step 4: 归一化
13        y = x / RMS = [3/3.54, 4/3.54] ≈ [0.85, 1.13]
14
15Step 5: 乘以可学习权重 γ
16        y = y × γ = [0.85×γ₁, 1.13×γ₂]
flowchart LR
    X["x = [3, 4]"] --> Square["x² = [9, 16]"]
    Square --> Mean["mean = 12.5"]
    Mean --> Add["+ ε = 12.500001"]
    Add --> Rsqrt["1/√... ≈ 0.283"]
    X --> Mul1["×"]
    Rsqrt --> Mul1
    Mul1 --> Result["[0.85, 1.13]"]
    Result --> Mul2["× γ"]
    Mul2 --> Y["最终输出"]

2.4.2 RMSNorm 公式与源码

公式: $$\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}} \cdot \gamma$$

源码解析:

 1class RMSNorm(nn.Module):
 2    def __init__(self, dim: int, eps: float = 1e-6):
 3        super().__init__()
 4        self.dim = dim
 5        self.eps = eps
 6        # 注意: weight 是 FP32 精度
 7        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
 8
 9    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
10        dtype = x.dtype
11
12        if residual is None:
13            # 模式1: 普通归一化
14            x = x.float()
15            var = x.pow(2).mean(-1, keepdim=True)  # 计算方差 (实际是均方)
16            x = x * torch.rsqrt(var + self.eps)    # 归一化
17            return (self.weight * x).to(dtype)
18        else:
19            # 模式2: 融合残差的归一化 (优化显存)
20            x = residual = x.float() + residual.float()  # 先加残差
21            var = x.pow(2).mean(-1, keepdim=True)
22            x = x * torch.rsqrt(var + self.eps)
23            return (self.weight * x).to(dtype), residual.to(dtype)

2.4.2 LayerNorm

 1class LayerNorm(nn.Module):
 2    def __init__(self, dim: int, eps: float = 1e-6):
 3        super().__init__()
 4        self.dim = dim
 5        self.eps = eps
 6        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
 7        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))  # 比 RMSNorm 多一个 bias
 8
 9    def forward(self, x: torch.Tensor):
10        return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)

2.4.3 RMSNorm vs LayerNorm

特性 RMSNorm LayerNorm
公式 $\frac{x}{\text{RMS}(x)} \cdot \gamma$ $\frac{x - \mu}{\sigma} \cdot \gamma + \beta$
参数 weight weight + bias
计算量 更少 (无需计算均值) 更多
使用场景 主干网络 Indexer 的 k_norm

3. 位置编码详解

3.0 通俗解释:让模型知道"谁在前谁在后"

为什么需要位置编码?

看这两句话:

  • “狗咬了人”
  • “人咬了狗”

词完全一样,但意思完全不同!区别在于词的顺序

Transformer 的注意力机制本身是"位置无关"的——它只看词与词之间的关系,不知道谁在前谁在后。所以我们需要给每个位置一个"身份标签"。


传统方法 vs RoPE:两种"标签"方式

方法 传统位置编码 RoPE (旋转位置编码)
思路 给每个位置加一个固定向量 根据位置旋转向量
类比 给每个学生发一个号码牌 让每个学生转不同的角度
优点 简单直观 可以表示相对位置,支持长度外推

RoPE 的核心思想:用"旋转"表示位置

想象一个时钟:

  • 位置 0 的向量:指向 12 点方向
  • 位置 1 的向量:顺时针转一点
  • 位置 2 的向量:再转一点
1位置 0:  ↑     位置 1:  ↗     位置 2:  →     位置 3:  ↘
2         |              /              —              \

关键特性:两个位置的"相对距离"可以通过旋转角度的差来表示!

比如位置 5 和位置 3 的相对距离是 2,无论它们在序列的哪里,旋转角度差都一样。


为什么用复数?

复数乘法天然就是旋转!

1复数 z = a + bi 可以表示为 (长度, 角度)
2复数乘法 z1 × z2 = (长度1×长度2, 角度1+角度2)

所以 向量 × e^(iθ) 就是把向量旋转 θ 角度。


具体例子:假设向量 x = [1, 0],位置 m=3,频率 θ=30°

 1Step 1: 把向量看作复数
 2        x = 1 + 0i
 3
 4Step 2: 计算旋转角度
 5        角度 = m × θ = 3 × 30° = 90°
 6
 7Step 3: 计算旋转因子
 8        e^(i×90°) = cos(90°) + i×sin(90°) = 0 + 1i
 9
10Step 4: 复数乘法
11        x' = (1 + 0i) × (0 + 1i) = 0 + 1i
12
13Step 5: 转回向量
14        x' = [0, 1]  (从指向右变成指向上,旋转了 90°)
flowchart LR
    subgraph "RoPE 旋转示意"
        X["原始向量 [1,0]
→"] --> Rotate["× e^(i×90°)"] Rotate --> Y["旋转后 [0,1]
↑"] end

3.1 precompute_freqs_cis 预计算

位置: model.py:324-402

作用: 预计算所有位置的旋转频率复数表示。

为什么要"预计算"?

  • 每个位置的旋转角度是固定的(只跟位置有关)
  • 提前算好存起来,推理时直接查表,省时间
flowchart TB
    subgraph "频率计算"
        Base["base = 10000"] --> Freqs["freqs = 1/(base^(2i/d))"]
        Freqs --> Check{"seq_len > original?"}
        Check -->|Yes| YaRN["YaRN 调整频率"]
        Check -->|No| Direct["直接使用"]
    end

    subgraph "复数表示"
        T["positions = [0,1,2,...,seq_len-1]"] --> Outer["outer(positions, freqs)"]
        Outer --> Polar["polar(1, angles)"]
        Polar --> FreqsCis["freqs_cis: 复数 e^(iθ)"]
    end

源码逐行解析:

 1def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
 2    dim = args.qk_rope_head_dim  # 64
 3    seqlen = args.max_seq_len    # 163840
 4    base = args.rope_theta       # 10000
 5    factor = args.rope_factor    # 40
 6
 7    # 计算基础频率: θ_i = 1 / (10000^(2i/64))
 8    # i ∈ [0, 2, 4, ..., 62], 共 32 个
 9    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
10    # freqs.shape = [32], 值从 1.0 衰减到 约 0.01
11
12    # 如果需要外推 (163840 > 4096)
13    if seqlen > args.original_seq_len:
14        # YaRN 调整 (详见 3.3 节)
15        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
16        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
17        freqs = freqs / factor * (1 - smooth) + freqs * smooth
18
19    # 计算每个位置的角度
20    t = torch.arange(seqlen)  # [0, 1, 2, ..., 163839]
21    freqs = torch.outer(t, freqs)  # [163840, 32] - 外积
22
23    # 转换为复数表示: e^(iθ) = cos(θ) + i*sin(θ)
24    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
25    # freqs_cis.shape = [163840, 32], dtype = complex64
26
27    return freqs_cis

3.2 apply_rotary_emb 应用旋转

位置: model.py:405-425

RoPE 原理:

将向量看作复数,通过乘以 $e^{i m \theta}$ 实现位置相关的旋转:

$$ \text{RoPE}(x, m) = x \cdot e^{im\theta} $$

其中 $m$ 是位置,$\theta$ 是频率。

flowchart TB
    subgraph "apply_rotary_emb"
        X["x: [B,S,H,D]
D=64 (qk_rope_head_dim)"] --> View["view_as_complex
[B,S,H,32] complex"] FreqsCis["freqs_cis: [S,32] complex"] --> Broadcast["broadcast to [1,S,1,32]"] View --> Mul["x × freqs_cis"] Broadcast --> Mul Mul --> Real["view_as_real
[B,S,H,32,2]"] Real --> Flatten["flatten → [B,S,H,64]"] end

源码逐行解析:

 1def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
 2    dtype = x.dtype
 3    shape = x.shape  # [B, S, H, D]
 4
 5    if not interleaved:
 6        # 非交错模式: [x0,x1,x2,...,x31,x32,...,x63] → [x0,x32,x1,x33,...,x31,x63]
 7        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
 8
 9    # 将实数对看作复数: [B,S,H,D] → [B,S,H,D/2] complex
10    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
11
12    # 广播 freqs_cis 到 [1, S, 1, D/2]
13    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
14
15    # 复数乘法实现旋转
16    y = torch.view_as_real(x * freqs_cis).flatten(3)  # [B,S,H,D]
17
18    if not interleaved:
19        # 恢复非交错布局
20        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
21
22    return y.to(dtype)

3.3 YaRN 长度外推

问题: 模型在 4096 长度上训练,如何外推到 163840?

解决方案: YaRN (Yet another RoPE extensioN) 对不同频率的维度采用不同的缩放策略。

flowchart TB
    subgraph "YaRN 策略"
        Low["低频维度 (i < low)
线性插值: freqs / factor"] High["高频维度 (i > high)
保持不变: freqs"] Mid["中间维度 (low ≤ i ≤ high)
平滑过渡"] end subgraph "参数含义" BetaFast["beta_fast=32
高频旋转次数阈值"] BetaSlow["beta_slow=1
低频旋转次数阈值"] Factor["factor=40
外推倍数"] end

源码中的辅助函数:

 1def find_correction_dim(num_rotations, dim, base, max_seq_len):
 2    """计算给定旋转次数对应的维度索引"""
 3    return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
 4
 5def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
 6    """计算需要调整的维度范围"""
 7    low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
 8    high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
 9    return max(low, 0), min(high, dim-1)
10
11def linear_ramp_factor(min, max, dim):
12    """生成平滑过渡的线性斜坡"""
13    linear_func = (torch.arange(dim) - min) / (max - min)
14    return torch.clamp(linear_func, 0, 1)

3.3.1 mscale:额外注意力缩放(实现细节)

除 RoPE 频率缩放外,MLA 里还会在“外推场景”额外调节 softmax 的缩放系数(只在 max_seq_len > original_seq_len 时生效):

1if args.max_seq_len > args.original_seq_len:
2    mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
3    self.softmax_scale = self.softmax_scale * mscale * mscale

这意味着在长上下文外推时,注意力的温度会随 rope_factor/mscale 一起被重新校准;对应参数在 inference/config_671B_v3.2.json 中提供。


4. MLA 注意力机制

4.0 通俗解释:压缩记忆的艺术

问题:KV Cache 太大了!

传统注意力机制需要缓存所有历史 token 的 K 和 V 向量。对于 DeepSeek-V3.2:

1传统 KV Cache 大小计算:
2- 序列长度: 163840 tokens
3- K 维度: 128 头 × 192 维 = 24576
4- V 维度: 128 头 × 128 维 = 16384
5- 每层缓存: 163840 × (24576 + 16384) × 2B ≈ 13.4 GB
6- 61 层总共: 13.4 × 61 ≈ 817 GB  😱

这显然不可接受!


MLA 的核心思想:先压缩再缓存

想象你要记笔记:

  • 传统方法:把老师说的每句话都逐字记下来 → 笔记本爆炸
  • MLA 方法:先把老师的话"压缩"成关键词,只记关键词 → 需要时再"展开"
1传统:x → K (24576维) → 缓存 K
2MLA: x → c_KV (512维) → 缓存 c_KV → 展开成 K (24576维)
3
4压缩率: 512 / 24576 ≈ 2%!

MLA Cache 大小计算

1MLA KV Cache 大小:
2- 序列长度: 163840 tokens
3- 压缩后维度: 512 + 64 (RoPE) = 576
4- 每层缓存: 163840 × 576 × 2B ≈ 188 MB
5- 61 层总共: 188 × 61 ≈ 11.5 GB  ✅
6
7节省: 817 GB → 11.5 GB,减少 98.6%!

MLA 的三个关键概念

概念 作用 类比
低秩压缩 用小矩阵近似大矩阵 用"摘要"代替"全文"
nope + rope 分离 把向量分成两部分处理 有些信息需要位置,有些不需要
延迟展开 缓存压缩版,用时再展开 压缩包存硬盘,用时解压

Q/K/V 的压缩与展开流程

flowchart TB
    subgraph "Query 路径"
        X1["x: 7168维"] --> WQA["wq_a: 压缩
7168→1536"] WQA --> QNorm["RMSNorm"] QNorm --> WQB["wq_b: 展开
1536→24576"] WQB --> Q["Q: 128头×192维"] end subgraph "KV 路径 (关键!)" X2["x: 7168维"] --> WKVA["wkv_a: 压缩
7168→576"] WKVA --> Split["分离"] Split --> CKV["c_kv: 512维
(缓存这个!)"] Split --> KPE["k_pe: 64维
(位置编码)"] CKV --> KVNorm["RMSNorm"] KVNorm --> Cache["💾 存入缓存"] KVNorm --> WKVB["wkv_b: 展开
512→32768"] WKVB --> KV["K_nope + V"] end

为什么叫"Multi-head Latent Attention"?

  • Multi-head: 多头注意力,每个头看不同的"视角"
  • Latent: “潜在的”,指那个压缩后的小向量 c_KV
  • Attention: 注意力机制

一句话总结:MLA 把 KV 压缩成一个"潜在向量"缓存起来,用的时候再展开,节省 98% 的显存。


4.1 MLA 类结构

位置: model.py:498-608

MLA vs 传统 MHA 对比:

graph LR
    subgraph "传统 MHA"
        X1[x] --> WQ["W_Q: D→H×D_h"]
        X1 --> WK["W_K: D→H×D_h"]
        X1 --> WV["W_V: D→H×D_h"]
        WQ --> Q1[Q]
        WK --> K1["K (缓存)"]
        WV --> V1["V (缓存)"]
    end

    subgraph "MLA"
        X2[x] --> WQA["W_QA: D→R_q"]
        WQA --> QNorm[RMSNorm]
        QNorm --> WQB["W_QB: R_q→H×D_h"]
        WQB --> Q2[Q]

        X2 --> WKVA["W_KVA: D→R_kv+D_rope"]
        WKVA --> KVNorm[RMSNorm]
        KVNorm --> C_KV["c_KV (只缓存这个!)"]
        C_KV --> WKVB["W_KVB: R_kv→H×(D_nope+D_v)"]
        WKVB --> K2[K]
        WKVB --> V2[V]
    end

类初始化:

 1class MLA(nn.Module):
 2    def __init__(self, args: ModelArgs):
 3        super().__init__()
 4        self.dim = args.dim                          # 7168
 5        self.n_heads = args.n_heads                  # 128
 6        self.n_local_heads = args.n_heads // world_size  # 128/8=16 (每 GPU)
 7        self.q_lora_rank = args.q_lora_rank          # 1536
 8        self.kv_lora_rank = args.kv_lora_rank        # 512
 9        self.qk_nope_head_dim = args.qk_nope_head_dim  # 128
10        self.qk_rope_head_dim = args.qk_rope_head_dim  # 64
11        self.qk_head_dim = 128 + 64                  # 192
12        self.v_head_dim = args.v_head_dim            # 128
13
14        # === Query 路径 ===
15        self.wq_a = Linear(7168, 1536)               # 压缩
16        self.q_norm = RMSNorm(1536)
17        self.wq_b = ColumnParallelLinear(1536, 128*192)  # 展开 (列并行)
18
19        # === KV 路径 ===
20        self.wkv_a = Linear(7168, 512+64)            # 压缩 (512 for KV, 64 for RoPE)
21        self.kv_norm = RMSNorm(512)
22        self.wkv_b = ColumnParallelLinear(512, 128*(128+128))  # 展开 (k_nope + v)
23
24        # === 输出 ===
25        self.wo = RowParallelLinear(128*128, 7168)   # 行并行
26
27        # === Indexer ===
28        self.indexer = Indexer(args)
29
30        # === KV Cache ===
31        self.register_buffer("kv_cache", torch.zeros(8, 163840, 512))  # 只缓存压缩的 c_KV
32        self.register_buffer("pe_cache", torch.zeros(8, 163840, 64))   # 缓存 k_pe

4.2 Query 路径详解

flowchart TB
    subgraph "Query 计算"
        X["x: [B,S,7168]"] --> WQA["wq_a
Linear(7168→1536)"] WQA --> C_Q["c_q: [B,S,1536]"] C_Q --> QNorm["q_norm
RMSNorm"] QNorm --> QR["qr: [B,S,1536]
(也传给 Indexer)"] QR --> WQB["wq_b
ColumnParallel(1536→128×192/8)"] WQB --> Q["q: [B,S,16,192]
(每 GPU 16 头)"] Q --> Split["split(128, 64)"] Split --> Q_NOPE["q_nope: [B,S,16,128]"] Split --> Q_PE["q_pe: [B,S,16,64]"] Q_PE --> RoPE["apply_rotary_emb"] RoPE --> Q_PE_ROT["q_pe (旋转后)"] end

源码:

 1# 在 forward() 中
 2qr = self.q_norm(self.wq_a(x))  # [B,S,1536] - 压缩 + 归一化
 3q = self.wq_b(qr)               # [B,S,16*192] - 展开 (列并行后每 GPU 16 头)
 4q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)  # [B,S,16,192]
 5
 6# 分离 nope 和 rope 部分
 7q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 8# q_nope: [B,S,16,128], q_pe: [B,S,16,64]
 9
10q_pe = apply_rotary_emb(q_pe, freqs_cis)  # 应用位置编码

4.3 KV 路径详解

flowchart TB
    subgraph "KV 计算"
        X["x: [B,S,7168]"] --> WKVA["wkv_a
Linear(7168→576)"] WKVA --> KV_PE["[B,S,576]"] KV_PE --> Split1["split(512, 64)"] Split1 --> KV["kv: [B,S,512]"] Split1 --> K_PE["k_pe: [B,S,64]"] KV --> KVNorm["kv_norm
RMSNorm"] KVNorm --> KV_Norm["kv (归一化)"] KV_Norm --> FP8["FP8 量化模拟"] FP8 --> Cache1["存入 kv_cache"] K_PE --> RoPE2["apply_rotary_emb"] RoPE2 --> K_PE_ROT["k_pe (旋转后)"] K_PE_ROT --> Cache2["存入 pe_cache"] KV_Norm --> WKVB["wkv_b
ColumnParallel(512→128×256/8)"] WKVB --> KV_Expand["[B,S,16,256]"] KV_Expand --> Split2["split(128, 128)"] Split2 --> K_NOPE["k_nope: [B,S,16,128]"] Split2 --> V["v: [B,S,16,128]"] end

源码:

 1# 在 forward() 中
 2kv = self.wkv_a(x)  # [B,S,576] - 压缩
 3kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
 4# kv: [B,S,512], k_pe: [B,S,64]
 5
 6kv = self.kv_norm(kv)  # 归一化
 7k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)  # 位置编码
 8
 9# FP8 量化模拟 (为了与部署一致)
10kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
11kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
12
13# 存入缓存
14self.kv_cache[:bsz, start_pos:end_pos] = kv      # 只存 512 维!
15self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

4.4 Prefill vs Decode 模式

4.4.1 Prefill 模式 (处理整个 prompt)

条件: mask is not None (seqlen > 1)

flowchart TB
    subgraph "Prefill 模式"
        Q["Q: [B,S,16,192]"] --> QK["Q @ K^T"]
        K["K: [B,S,16,192]"] --> QK
        QK --> Scale["× softmax_scale"]
        Scale --> Indexer["+ Indexer Mask"]
        Indexer --> CausalMask["+ Causal Mask"]
        CausalMask --> Softmax["softmax"]
        Softmax --> AV["@ V"]
        V["V: [B,S,16,128]"] --> AV
        AV --> Output["[B,S,16,128]"]
    end

源码:

 1if mask is not None:  # Prefill
 2    q = torch.cat([q_nope, q_pe], dim=-1)  # [B,S,16,192]
 3
 4    # 展开 KV
 5    kv = self.wkv_b(kv)  # [B,S,16*256]
 6    kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
 7    k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 8    k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
 9
10    # 注意力计算
11    scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
12
13    # Indexer 稀疏化
14    topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
15    index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device)
16    index_mask = index_mask.scatter_(-1, topk_indices, 0)  # 只保留 top-k
17    scores += index_mask.unsqueeze(2) + mask
18
19    scores = scores.softmax(dim=-1)
20    x = torch.einsum("bsht,bthd->bshd", scores, v)

4.4.2 Decode 模式 (逐 token 生成)

条件: mask is None (seqlen == 1)

关键优化: 不展开 K,而是先用 Q_nope 与 W_kv_b 合并,再与缓存的 c_KV 计算。

flowchart TB
    subgraph "Decode 模式"
        Q_NOPE["q_nope: [B,1,16,128]"] --> EINSUM1["einsum(bshd,hdc→bshc)"]
        WKV_B["wkv_b[:, :128, :]: [16,128,512]"] --> EINSUM1
        EINSUM1 --> Q_NOPE_PROJ["[B,1,16,512]"]
        Q_NOPE_PROJ --> EINSUM2["einsum(bshc,btc→bsht)"]
        KV_CACHE["kv_cache: [B,T,512]"] --> EINSUM2
        EINSUM2 --> Score1["score1"]

        Q_PE["q_pe: [B,1,16,64]"] --> EINSUM3["einsum(bshr,btr→bsht)"]
        PE_CACHE["pe_cache: [B,T,64]"] --> EINSUM3
        EINSUM3 --> Score2["score2"]

        Score1 --> Add["+ × scale"]
        Score2 --> Add
        Add --> IndexMask["+ Indexer Mask"]
        IndexMask --> Softmax["softmax"]
        Softmax --> EINSUM4["einsum(bsht,btc→bshc)"]
        KV_CACHE --> EINSUM4
        EINSUM4 --> Out1["[B,1,16,512]"]
        Out1 --> EINSUM5["einsum(bshc,hdc→bshd)"]
        WKV_B2["wkv_b[:, -128:, :]: [16,128,512]"] --> EINSUM5
        EINSUM5 --> Output["[B,1,16,128]"]
    end

源码:

 1else:  # Decode
 2    # 延迟反量化 wkv_b (只做一次)
 3    if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
 4        self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
 5    wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
 6    wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)  # [16, 256, 512]
 7
 8    # 关键优化: Q_nope 先与 W_kv_b 的 K 部分结合
 9    q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
10    # [B,1,16,128] @ [16,128,512] → [B,1,16,512]
11
12    # 与缓存计算注意力分数
13    scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
14              torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
15
16    # Indexer 稀疏化
17    topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
18    index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device)
19    index_mask = index_mask.scatter_(-1, topk_indices, 0)
20    scores += index_mask.unsqueeze(2)
21
22    scores = scores.softmax(dim=-1)
23
24    # 计算输出
25    x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
26    x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])

5. Indexer 稀疏注意力

5.0 通俗解释:精选"关键词"的搜索引擎

问题:注意力太"民主"了!

传统注意力机制会让每个 token 关注所有历史 token。对于 163840 长度的上下文:

1传统注意力计算量:
2- 每个新 token 要和 163840 个历史 token 计算相似度
3- 计算量 = 163840 × 头数 × 维度 ≈ 非常大!

但实际上,大部分历史 token 跟当前问题根本没关系


Indexer 的核心思想:先"粗筛"再"精选"

想象你在图书馆找资料:

  • 传统方法:把图书馆所有书都翻一遍 → 太慢!
  • Indexer 方法
    1. 先用"关键词"快速搜索(粗筛)
    2. 只拿出最相关的 2048 本书(精选)
    3. 仔细阅读这 2048 本(精细注意力)
flowchart LR
    subgraph "Indexer 工作流程"
        All["163840 个历史 token"] --> Coarse["粗筛:快速评分"]
        Coarse --> TopK["选出 top-2048"]
        TopK --> Fine["精细注意力"]
        Fine --> Output["最终结果"]
    end

为什么选 2048?

选择数量 效果 计算量
全部 (163840) 最准确 100%
2048 几乎不损失 1.25%
512 有损失 0.3%

实验发现 2048 是个甜蜜点:计算量降到 1.25%,效果几乎不变!


Indexer 的评分公式

1对于当前 token 的 query q 和历史 token 的 key k:
2
3index_score = Σ (ReLU(q·k) × weight)
4
5- q·k: 相似度分数
6- ReLU: 只保留正分数(负相关的直接忽略)
7- weight: 每个注意力头的权重
8- Σ: 对所有头求和

直觉理解

  • 正分数 = “这本书可能有用” → 保留
  • 负/零分数 = “这本书肯定没用” → 忽略
  • weight = “这个头的判断有多可靠”

Hadamard 变换:让量化更准确

Indexer 使用 FP8 量化来加速计算,但 FP8 对数值分布有要求。Hadamard 变换就是一个"搅拌"操作:

1原始向量: [1, 0, 0, 0, ...]   ← 很多 0,不利于量化
2变换后:   [0.1, 0.1, 0.1, 0.1, ...] ← 值更均匀,量化更准
flowchart LR
    X["原始向量
[1,0,0,0]"] --> H["Hadamard
变换"] H --> Y["均匀化向量
[0.5,0.5,0.5,0.5]"] Y --> Q["FP8 量化"] Q --> GEMM["快速矩阵乘"]

与 MLA 的配合

flowchart TB
    subgraph "MLA 主路径"
        X["输入 x"] --> MLA_Q["生成 Q"]
        X --> MLA_KV["生成 KV Cache"]
    end

    subgraph "Indexer 辅助路径"
        X --> IDX_K["生成 Index K"]
        MLA_Q --> QR["qr (压缩的Q)"]
        QR --> IDX_Q["生成 Index Q"]
        IDX_Q --> Score["计算 Index Score"]
        IDX_K --> Score
        Score --> TopK["选出 top-2048"]
    end

    TopK --> Mask["生成稀疏 Mask"]
    MLA_KV --> Attn["稀疏注意力"]
    Mask --> Attn
    Attn --> Output["输出"]

一句话总结:Indexer 就是注意力机制的"搜索引擎",先快速筛选出最相关的 2048 个 token,再做精细注意力。


5.1 Indexer 类结构

位置: model.py:435-487

作用: 从所有历史 token 中选择最相关的 top-k (2048) 个进行精细注意力。

 1class Indexer(torch.nn.Module):
 2    def __init__(self, args: ModelArgs):
 3        super().__init__()
 4        self.dim = args.dim                    # 7168
 5        self.n_heads = args.index_n_heads      # 64
 6        self.n_local_heads = 64 // world_size  # 8 (每 GPU)
 7        self.head_dim = args.index_head_dim    # 128
 8        self.rope_head_dim = args.qk_rope_head_dim  # 64
 9        self.index_topk = args.index_topk      # 2048
10        self.q_lora_rank = args.q_lora_rank    # 1536
11
12        # Indexer 自己的投影层
13        self.wq_b = Linear(1536, 64*128)       # Q 投影 (使用 MLA 的 qr)
14        self.wk = Linear(7168, 128)            # K 投影 (从原始 x)
15        self.k_norm = LayerNorm(128)           # K 归一化
16        self.weights_proj = Linear(7168, 64, dtype=torch.float32)  # 权重投影
17
18        # Indexer 的 K Cache (FP8)
19        self.register_buffer("k_cache", torch.zeros(8, 163840, 128, dtype=torch.float8_e4m3fn))
20        self.register_buffer("k_scale_cache", torch.zeros(8, 163840, 1))  # 128/128=1 个 scale

5.2 前向传播流程

flowchart TB
    subgraph "Indexer Forward"
        X["x: [B,S,7168]"] --> WK["wk(x)"]
        WK --> K_Norm["k_norm"]
        K_Norm --> K["k: [B,S,128]"]
        K --> K_Split["split(64, 64)"]
        K_Split --> K_PE["k_pe: 64"]
        K_Split --> K_NOPE["k_nope: 64"]
        K_PE --> RoPE_K["RoPE (non-interleaved)"]
        RoPE_K --> K_Cat["concat"]
        K_NOPE --> K_Cat
        K_Cat --> Hadamard_K["Hadamard 变换"]
        Hadamard_K --> Quant_K["FP8 量化"]
        Quant_K --> K_Cache["存入 k_cache"]

        QR["qr: [B,S,1536]
(from MLA)"] --> WQ_B["wq_b(qr)"] WQ_B --> Q["q: [B,S,64,128]"] Q --> Q_Split["split(64, 64)"] Q_Split --> Q_PE["q_pe: 64"] Q_Split --> Q_NOPE["q_nope: 64"] Q_PE --> RoPE_Q["RoPE (non-interleaved)"] RoPE_Q --> Q_Cat["concat"] Q_NOPE --> Q_Cat Q_Cat --> Hadamard_Q["Hadamard 变换"] Hadamard_Q --> Quant_Q["FP8 量化"] Quant_Q --> FP8_Index["fp8_index Kernel"] K_Cache --> FP8_Index X --> Weights_Proj["weights_proj(x)"] Weights_Proj --> Weights["weights: [B,S,64]"] Weights --> FP8_Index FP8_Index --> Scores["index_scores: [B,S,T]"] Scores --> Mask["+ causal mask"] Mask --> TopK["top-k(2048)"] TopK --> Indices["topk_indices: [B,S,2048]"] end

源码逐行解析:

 1def forward(self, x, qr, start_pos, freqs_cis, mask):
 2    bsz, seqlen, _ = x.size()
 3    end_pos = start_pos + seqlen
 4
 5    # === Q 计算 ===
 6    q = self.wq_b(qr)  # [B,S,64*128]
 7    q = q.view(bsz, seqlen, self.n_heads, self.head_dim)  # [B,S,64,128]
 8    q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
 9    # 注意: Indexer 的 RoPE 是非交错的
10    q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
11    q = torch.cat([q_pe, q_nope], dim=-1)
12
13    # === K 计算 ===
14    k = self.wk(x)     # [B,S,128]
15    k = self.k_norm(k)  # LayerNorm
16    k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
17    k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
18    k = torch.cat([k_pe, k_nope], dim=-1)
19
20    # === Hadamard 变换 ===
21    q = rotate_activation(q)  # 随机正交变换
22    k = rotate_activation(k)
23
24    # === FP8 量化 ===
25    q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
26    k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
27
28    # === 更新 K Cache ===
29    self.k_cache[:bsz, start_pos:end_pos] = k_fp8
30    self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
31
32    # === 计算 Index Score ===
33    weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
34    weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
35
36    # FP8 矩阵乘 + ReLU + 求和
37    index_score = fp8_index(q_fp8.contiguous(), weights,
38                            self.k_cache[:bsz, :end_pos].contiguous(),
39                            self.k_scale_cache[:bsz, :end_pos].contiguous())
40
41    if mask is not None:
42        index_score += mask  # 加上因果 mask
43
44    # === Top-K 选择 ===
45    topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
46
47    # 同步验证 (确保所有 GPU 选择相同的 token)
48    topk_indices_ = topk_indices.clone()
49    dist.broadcast(topk_indices_, src=0)
50    assert torch.all(topk_indices == topk_indices_)
51
52    return topk_indices

5.3 Hadamard 变换

位置: model.py:428-432

作用: 对 Q/K 进行随机正交变换,使它们的分布更均匀,有利于 FP8 量化。

1def rotate_activation(x: torch.Tensor) -> torch.Tensor:
2    assert x.dtype == torch.bfloat16
3    from fast_hadamard_transform import hadamard_transform
4    hidden_size = x.size(-1)
5    return hadamard_transform(x, scale=hidden_size ** -0.5)

Hadamard 变换:

$$H_n = \frac{1}{\sqrt{n}} \begin{pmatrix} H_{n/2} & H_{n/2} \ H_{n/2} & -H_{n/2} \end{pmatrix}$$

特点:

  • 正交变换,保持向量范数
  • 快速算法 O(n log n)
  • 使激活值分布更均匀

6. 前馈网络

6.1 MLP 结构

位置: model.py:611-643

flowchart LR
    X["x: [B,S,D]"] --> W1["w1 (gate_proj)
ColumnParallel"] X --> W3["w3 (up_proj)
ColumnParallel"] W1 --> SiLU["SiLU"] SiLU --> Mul["×"] W3 --> Mul Mul --> W2["w2 (down_proj)
RowParallel"] W2 --> Y["y: [B,S,D]"]

源码:

 1class MLP(nn.Module):
 2    def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
 3        super().__init__()
 4        self.w1 = ColumnParallelLinear(dim, inter_dim)  # gate_proj
 5        self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)  # down_proj
 6        self.w3 = ColumnParallelLinear(dim, inter_dim)  # up_proj
 7
 8    def forward(self, x: torch.Tensor) -> torch.Tensor:
 9        # SwiGLU: SiLU(W1·x) × W3·x
10        return self.w2(
11            (F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x)
12        )

6.2 SwiGLU 激活

公式:

$$\text{SwiGLU}(x) = \text{SiLU}(W_1 \cdot x) \odot (W_3 \cdot x)$$

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

为什么用 SwiGLU?

  • 比 ReLU/GELU 更平滑
  • 门控机制提供更强的表达能力
  • 在 LLM 中效果更好

7. MoE 混合专家系统

7.0 通俗解释:256 个专家的"会诊"系统

问题:一个模型难以精通所有领域

想象你去医院看病:

  • 感冒找内科医生
  • 骨折找骨科医生
  • 心脏问题找心内科医生

每个医生都是某个领域的"专家"。如果只有一个全科医生,他很难在所有领域都做到顶尖。


MoE 的核心思想:让专家"会诊"

DeepSeek-V3.2 有 256 个专家 + 1 个共享专家

1256 个路由专家:每个专家擅长某类问题
2  - 专家 0-31: 可能擅长数学
3  - 专家 32-63: 可能擅长代码
4  - 专家 64-95: 可能擅长语言理解
5  - ... (自动学习的,无需人工指定)
6
71 个共享专家:处理通用信息,所有 token 都会经过

关键点:每个 token 只激活 8 个专家,不是全部 256 个!


路由决策流程

flowchart TB
    subgraph "分组路由策略"
        Token["输入 token"] --> Gate["计算 256 个专家的分数"]
        Gate --> Group["分成 8 组,每组 32 个专家"]
        Group --> G1["组1: 专家 0-31"]
        Group --> G2["组2: 专家 32-63"]
        Group --> G8["组8: 专家 224-255"]
        G1 --> GS1["组1 得分"]
        G2 --> GS2["组2 得分"]
        G8 --> GS8["组8 得分"]
        GS1 --> Top4["选出得分最高的 4 组"]
        GS2 --> Top4
        GS8 --> Top4
        Top4 --> Top8["从这 4 组中选 8 个专家"]
    end

为什么分组?

  • 直接从 256 个里选 8 个 → 可能全挤在某几个专家
  • 分组后选 → 保证专家的多样性,负载更均衡

具体例子:假设输入是一道数学编程题

 1Step 1: 计算 256 个专家的分数
 2        专家 5 (数学): 0.95
 3        专家 40 (代码): 0.92
 4        专家 67 (逻辑): 0.88
 5        专家 102 (推理): 0.85
 6        ... (其他专家分数较低)
 7
 8Step 2: 分组计算组得分 (每组取 top-2 求和)
 9        组1 (0-31):   专家5=0.95 + 专家12=0.45 = 1.40
10        组2 (32-63):  专家40=0.92 + 专家55=0.50 = 1.42
11        组3 (64-95):  专家67=0.88 + 专家80=0.60 = 1.48
12        组4 (96-127): 专家102=0.85 + 专家110=0.55 = 1.40
13        ... (其他组得分较低)
14
15Step 3: 选出 top-4 组
16        ✓ 组3 (1.48)
17        ✓ 组2 (1.42)
18        ✓ 组1 (1.40)
19        ✓ 组4 (1.40)
20
21Step 4: 从这 4 组的 128 个专家中选 top-8
22        ✓ 专家 5, 40, 67, 102, 12, 55, 80, 110
23
24Step 5: 归一化权重并缩放 ×2.5
25        权重 = [0.18, 0.17, 0.16, 0.15, 0.10, 0.09, 0.08, 0.07]

专家计算流程

flowchart LR
    subgraph "MoE Forward"
        X["输入 x"] --> Gate["Gate: 选 8 个专家"]
        Gate --> W["权重: [w1...w8]"]
        Gate --> Idx["专家索引: [e1...e8]"]

        X --> E1["专家 e1"]
        X --> E2["专家 e2"]
        X --> E8["专家 e8"]

        E1 --> Mul1["× w1"]
        E2 --> Mul2["× w2"]
        E8 --> Mul8["× w8"]

        Mul1 --> Sum["Σ 求和"]
        Mul2 --> Sum
        Mul8 --> Sum

        X --> Shared["共享专家"]
        Shared --> Sum

        Sum --> Y["输出 y"]
    end

共享专家 vs 路由专家

类型 数量 激活条件 作用
路由专家 256 被 Gate 选中才激活 处理特定领域
共享专家 1 所有 token 都激活 处理通用信息

类比

  • 路由专家 = 各科室的专家医生(按需会诊)
  • 共享专家 = 全科医生(每个病人都先经过)

显存与计算的巧妙平衡

1参数量计算:
2- 如果用一个大 MLP: 7168 × 18432 × 3 ≈ 4 亿参数
3- 实际用 MoE:
4  - 256 个小专家: 256 × (7168 × 2048 × 3) ≈ 113 亿参数
5  - 每次只激活 8 个: 8 × ... ≈ 3.5 亿参数激活
6
7结果:参数量增加 28 倍,但计算量几乎不变!

一句话总结:MoE 就是让 256 个专家"分诊",每个 token 只找最相关的 8 个专家处理,既保持了大模型的能力,又控制了计算量。


7.1 Gate 路由器

位置: model.py:646-709

作用: 为每个 token 选择 top-k 个专家。

flowchart TB
    subgraph "Gate Forward"
        X["x: [B×S, 7168]"] --> Linear["x @ weight.T"]
        Linear --> Scores["scores: [B×S, 256]"]
        Scores --> Sigmoid["sigmoid"]
        Sigmoid --> OrigScores["original_scores"]

        Scores --> AddBias["+ bias"]
        AddBias --> GroupView["view: [B×S, 8, 32]"]
        GroupView --> GroupTopK["每组 top-2 求和"]
        GroupTopK --> GroupScores["group_scores: [B×S, 8]"]
        GroupScores --> SelectGroups["top-4 组"]
        SelectGroups --> Mask["mask 其他组"]
        Mask --> Flatten["flatten: [B×S, 256]"]
        Flatten --> TopK8["top-8 专家"]
        TopK8 --> Indices["indices: [B×S, 8]"]

        OrigScores --> Gather["gather"]
        Indices --> Gather
        Gather --> Weights["weights"]
        Weights --> Normalize["归一化"]
        Normalize --> Scale["× 2.5"]
        Scale --> FinalWeights["final_weights: [B×S, 8]"]
    end

源码逐行解析:

 1class Gate(nn.Module):
 2    def __init__(self, args: ModelArgs):
 3        super().__init__()
 4        self.dim = args.dim                      # 7168
 5        self.topk = args.n_activated_experts     # 8
 6        self.n_groups = args.n_expert_groups     # 8
 7        self.topk_groups = args.n_limited_groups # 4
 8        self.score_func = args.score_func        # "sigmoid"
 9        self.route_scale = args.route_scale      # 2.5
10
11        self.weight = nn.Parameter(torch.empty(256, 7168))  # 路由权重
12        # bias 只在 dim=7168 时启用 (即完整模型)
13        self.bias = nn.Parameter(torch.empty(256, dtype=torch.float32)) if self.dim == 7168 else None
14
15    def forward(self, x):
16        # 1. 计算原始分数
17        scores = linear(x.float(), self.weight.float())  # [B*S, 256]
18
19        # 2. Sigmoid 激活
20        scores = scores.sigmoid()
21        original_scores = scores  # 保存用于最终权重
22
23        # 3. 添加 bias (用于调整专家偏好)
24        if self.bias is not None:
25            scores = scores + self.bias
26
27        # 4. 分组路由 (256 专家分成 8 组,每组 32 个)
28        if self.n_groups > 1:
29            scores = scores.view(x.size(0), self.n_groups, -1)  # [B*S, 8, 32]
30
31            # 每组取 top-2 专家分数之和作为组分数
32            group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)  # [B*S, 8]
33
34            # 选择 top-4 组
35            indices = group_scores.topk(self.topk_groups, dim=-1)[1]  # [B*S, 4]
36
37            # mask 掉未选中的组
38            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool)
39            mask = mask.scatter_(1, indices, False)
40            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf"))
41            scores = scores.flatten(1)  # [B*S, 256]
42
43        # 5. 从选中的组中选 top-8 专家
44        indices = scores.topk(self.topk, dim=-1)[1]  # [B*S, 8]
45
46        # 6. 从原始分数中获取权重 (不是 mask 后的分数)
47        weights = original_scores.gather(1, indices)  # [B*S, 8]
48
49        # 7. 归一化 (sigmoid 模式下)
50        weights /= weights.sum(dim=-1, keepdim=True)
51
52        # 8. 缩放
53        weights *= self.route_scale  # × 2.5
54
55        return weights, indices

7.2 Expert 专家层

位置: model.py:712-744

结构: 与 MLP 相同,但使用普通 Linear (不并行)

1class Expert(nn.Module):
2    def __init__(self, dim: int, inter_dim: int):
3        super().__init__()
4        self.w1 = Linear(dim, inter_dim)   # 7168 → 2048
5        self.w2 = Linear(inter_dim, dim)   # 2048 → 7168
6        self.w3 = Linear(dim, inter_dim)   # 7168 → 2048
7
8    def forward(self, x):
9        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))

7.3 MoE 整体结构

位置: model.py:747-804

flowchart TB
    subgraph "MoE Forward"
        X["x: [B,S,7168]"] --> Flatten["flatten: [B×S, 7168]"]
        Flatten --> Gate["Gate"]
        Gate --> Weights["weights: [B×S, 8]"]
        Gate --> Indices["indices: [B×S, 8]"]

        Flatten --> Shared["shared_experts
(MLP)"] Shared --> Y_Shared["y_shared"] subgraph "并行专家计算" Indices --> Loop["for i in local_experts"] Loop --> Find["找到选中专家 i 的 token"] Find --> Expert_i["Expert_i(x[idx])"] Expert_i --> Weighted["× weights"] Weighted --> Accumulate["累加到 y"] end Y_Shared --> Add["y += y_shared"] Accumulate --> Add Add --> AllReduce["all_reduce"] AllReduce --> Output["y: [B,S,7168]"] end

源码逐行解析:

 1class MoE(nn.Module):
 2    def __init__(self, args: ModelArgs):
 3        super().__init__()
 4        self.dim = args.dim
 5        self.n_routed_experts = args.n_routed_experts        # 256
 6        self.n_local_experts = args.n_routed_experts // world_size  # 32 (每 GPU)
 7        self.n_activated_experts = args.n_activated_experts  # 8
 8
 9        # 当前 GPU 负责的专家范围
10        self.experts_start_idx = rank * self.n_local_experts
11        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
12
13        self.gate = Gate(args)
14
15        # 只创建本 GPU 负责的专家
16        self.experts = nn.ModuleList([
17            Expert(args.dim, args.moe_inter_dim)
18            if self.experts_start_idx <= i < self.experts_end_idx
19            else None
20            for i in range(self.n_routed_experts)
21        ])
22
23        # 共享专家 (所有 token 都经过)
24        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim,
25                                  reduce_output=False)
26
27    def forward(self, x):
28        shape = x.size()
29        x = x.view(-1, self.dim)  # [B*S, 7168]
30
31        # 路由
32        weights, indices = self.gate(x)  # [B*S, 8], [B*S, 8]
33
34        y = torch.zeros_like(x, dtype=torch.float32)
35
36        # 统计每个专家被选中的次数
37        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
38
39        # 只计算本 GPU 负责的专家
40        for i in range(self.experts_start_idx, self.experts_end_idx):
41            if counts[i] == 0:
42                continue
43            expert = self.experts[i]
44            # 找到选中专家 i 的 (token_idx, top_idx)
45            idx, top = torch.where(indices == i)
46            # 计算并加权累加
47            y[idx] += expert(x[idx]) * weights[idx, top, None]
48
49        # 加上共享专家输出
50        y += self.shared_experts(x)
51
52        # 跨 GPU 合并
53        if world_size > 1:
54            dist.all_reduce(y)
55
56        return y.type_as(x).view(shape)

8. Block 与 Transformer

8.0 通俗解释:61 层"加工流水线"

Transformer 就是一条流水线

想象一个工厂的生产线:

  • 原材料(token)从一端进入
  • 经过 61 道工序(Block)加工
  • 最后输出成品(预测的下一个词)

每道工序都做两件事:

  1. MLA 注意力:让这个位置"看看"其他位置的信息
  2. FFN/MoE:对信息进行"深加工"
flowchart LR
    subgraph "Transformer 流水线"
        Input["输入 tokens"] --> Embed["词嵌入"]
        Embed --> B0["Block 0
(Dense)"] B0 --> B1["Block 1
(Dense)"] B1 --> B2["Block 2
(Dense)"] B2 --> B3["Block 3
(MoE)"] B3 --> Dots["..."] Dots --> B60["Block 60
(MoE)"] B60 --> Norm["最终归一化"] Norm --> Head["输出投影"] Head --> Output["预测 logits"] end

为什么前 3 层用 Dense,后面用 MoE?

层级 类型 原因
0-2 层 Dense MLP 初期需要建立基础表示,Dense 更稳定
3-60 层 MoE 后期需要专业化处理,MoE 更高效

类比

  • 前 3 层 = 通识教育(所有学生学一样的基础课)
  • 后 58 层 = 专业教育(不同学生选不同的专业课)

残差连接:信息的"高速公路"

每个 Block 都有残差连接,让信息可以"跳过"某些层直接传递:

1传统方式:x → Block → y
2残差方式:x → Block → y + x  (输出 = 处理结果 + 原始输入)

为什么需要残差?

  • 防止梯度消失(61 层太深了!)
  • 让模型可以学习"不改变"(如果 Block 输出 0,结果就是原始输入)
  • 让浅层特征可以直接传到深层
flowchart LR
    subgraph "残差连接"
        X["输入 x"] --> Block["Block 处理"]
        Block --> Add["+"]
        X --> Add
        Add --> Y["输出 y = x + Block(x)"]
    end

DeepSeek 的"融合残差"优化

传统残差需要两次加法(注意力后一次,FFN 后一次)。DeepSeek 把残差融合到 RMSNorm 里:

1# 传统方式 (需要存储中间结果)
2x = x + attention(norm(x))   # 第一次加法
3x = x + ffn(norm(x))         # 第二次加法
4
5# 融合方式 (节省显存)
6x, residual = norm(x, residual)  # 归一化同时处理残差
7x = attention(x)
8x, residual = norm(x, residual)  # 归一化同时处理残差
9x = ffn(x)

完整的前向传播流程

sequenceDiagram
    participant T as Tokens
    participant E as Embedding
    participant B as Block (×61)
    participant N as Final Norm
    participant H as Head

    T->>E: [B, S] 词 ID
    E->>B: [B, S, 7168] 向量

    loop 61 次
        B->>B: RMSNorm + MLA
        B->>B: RMSNorm + FFN/MoE
    end

    B->>N: [B, S, 7168]
    N->>H: [B, 7168] (只取最后一个位置)
    H->>T: [B, 129280] logits

一句话总结:Transformer 是 61 层的加工流水线,每层先"看上下文"(MLA),再"深加工"(FFN/MoE),通过残差连接保证信息流通。


8.1 Block 结构

位置: model.py:807-851

flowchart TB
    subgraph "Block Forward"
        X["x"] --> CheckRes{"residual is None?"}
        Res["residual"] --> CheckRes

        CheckRes -->|Yes| Init["x, residual = attn_norm(x), x"]
        CheckRes -->|No| Fused["x, residual = attn_norm(x, residual)"]

        Init --> MLA["MLA(x)"]
        Fused --> MLA
        MLA --> X2["x (attention output)"]

        X2 --> FFN_Norm["x, residual = ffn_norm(x, residual)"]
        Res2["residual"] --> FFN_Norm

        FFN_Norm --> FFN["MLP / MoE"]
        FFN --> X3["x (ffn output)"]
        FFN_Norm --> Res3["residual (updated)"]

        X3 --> Output["return (x, residual)"]
        Res3 --> Output
    end

源码:

 1class Block(nn.Module):
 2    def __init__(self, layer_id: int, args: ModelArgs):
 3        super().__init__()
 4        self.attn = MLA(args)
 5        # 前 3 层用 Dense MLP,之后用 MoE
 6        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
 7        self.attn_norm = RMSNorm(args.dim)
 8        self.ffn_norm = RMSNorm(args.dim)
 9
10    def forward(self, x, residual, start_pos, freqs_cis, mask):
11        # 注意力前归一化 + 残差处理
12        if residual is None:
13            x, residual = self.attn_norm(x), x
14        else:
15            x, residual = self.attn_norm(x, residual)
16
17        # 注意力
18        x = self.attn(x, start_pos, freqs_cis, mask)
19
20        # FFN 前归一化 + 残差处理
21        x, residual = self.ffn_norm(x, residual)
22
23        # FFN
24        x = self.ffn(x)
25
26        return x, residual

8.2 Transformer 主类

位置: model.py:854-913

flowchart TB
    subgraph "Transformer Forward"
        Tokens["tokens: [B, S]"] --> Embed["ParallelEmbedding"]
        Embed --> H["h: [B, S, 7168]"]
        H --> Init["residual = None"]

        Init --> FreqsCis["freqs_cis = self.freqs_cis[start:end]"]
        FreqsCis --> Mask["mask = triu(-inf) if S>1 else None"]

        Mask --> Loop["for layer in self.layers"]
        Loop --> Block["h, residual = layer(h, residual, ...)"]
        Block --> Loop

        Block --> FinalNorm["h, _ = norm(h, residual)"]
        FinalNorm --> Head["head(h[:, -1])"]
        Head --> Logits["logits: [B, vocab_size/8]"]
        Logits --> AllGather["all_gather"]
        AllGather --> FullLogits["logits: [B, vocab_size]"]
    end

源码:

 1class Transformer(nn.Module):
 2    def __init__(self, args: ModelArgs):
 3        global world_size, rank
 4        world_size = dist.get_world_size() if dist.is_initialized() else 1
 5        rank = dist.get_rank() if dist.is_initialized() else 0
 6
 7        # 设置全局精度
 8        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
 9        Linear.scale_fmt = args.scale_fmt
10
11        super().__init__()
12        self.max_seq_len = args.max_seq_len
13        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
14
15        # 61 层 Block
16        self.layers = torch.nn.ModuleList()
17        for layer_id in range(args.n_layers):
18            self.layers.append(Block(layer_id, args))
19
20        self.norm = RMSNorm(args.dim)
21        # lm_head 用 FP32 以获得更精确的 logits
22        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
23
24        # 预计算位置编码
25        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
26
27    @torch.inference_mode()
28    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
29        seqlen = tokens.size(1)
30
31        # 获取当前位置的频率
32        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
33
34        # 因果 mask (只在 prefill 时使用)
35        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
36
37        # 嵌入
38        h, residual = self.embed(tokens), None
39
40        # 61 层 Block
41        for layer in self.layers:
42            h, residual = layer(h, residual, start_pos, freqs_cis, mask)
43
44        # 最终归一化
45        h, _ = self.norm(h, residual)
46
47        # 输出投影 (只取最后一个 token)
48        logits = self.head(h[:, -1].float())
49
50        # 合并所有 GPU 的 logits
51        if world_size > 1:
52            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
53            dist.all_gather(all_logits, logits)
54            logits = torch.cat(all_logits, dim=-1)
55
56        return logits

8.3 残差连接策略

DeepSeek-V3.2 使用"融合残差"策略:

flowchart LR
    subgraph "传统残差"
        X1[x] --> Norm1[Norm]
        Norm1 --> Attn1[Attention]
        Attn1 --> Add1["+"]
        X1 --> Add1
        Add1 --> Norm2[Norm]
        Norm2 --> FFN1[FFN]
        FFN1 --> Add2["+"]
        Add1 --> Add2
    end

    subgraph "融合残差 (节省显存)"
        X2[x] --> FusedNorm1["Norm + 加残差"]
        Res[residual] --> FusedNorm1
        FusedNorm1 --> Attn2[Attention]
        FusedNorm1 -->|new residual| FusedNorm2["Norm + 加残差"]
        Attn2 --> FusedNorm2
        FusedNorm2 --> FFN2[FFN]
        FusedNorm2 -->|new residual| Out[output]
    end

9. FP8 量化与 CUDA Kernel

位置: kernel.py

实现说明:本仓库的 kernel 由 tilelang.jit 定义(运行时/首次调用时编译),并非手写 .cu 文件;同时通过 pass_configs 关闭部分优化选项以保证兼容性与稳定性。

9.1 act_quant 激活量化

flowchart TB
    subgraph "act_quant"
        X["x: BF16 [M, N]"] --> Reshape["reshape: [M, N/128, 128]"]
        Reshape --> AbsMax["absmax per block"]
        AbsMax --> Scale["scale = absmax / 448"]
        Scale --> Clamp["x_fp8 = clamp(x/scale, -448, 448)"]
        Clamp --> Y["y: FP8 [M, N]"]
        Scale --> S["scale: FP32 [M, N/128]"]
    end

9.2 fp8_gemm FP8 矩阵乘法

flowchart TB
    subgraph "fp8_gemm"
        A["A: FP8 [M, K]"] --> GEMM["Tensor Core GEMM"]
        B["B: FP8 [N, K]"] --> GEMM
        GEMM --> Acc["累加器: FP32"]
        SA["scale_A: [M, K/128]"] --> ScaleMul["scale_A × scale_B"]
        SB["scale_B: [N/128, K/128]"] --> ScaleMul
        ScaleMul --> Acc
        Acc --> C["C: BF16 [M, N]"]
    end

9.3 fp8_index:Indexer 用的 FP8 打分 Kernel

Indexer 会把 q/k 量化到 FP8,并调用 fp8_index 计算稀疏选择所需的打分矩阵(输出 FP32 分数,供后续 top-k 选择)。


10. 推理流程

位置: generate.py

10.1 完整推理时序图

sequenceDiagram
    participant User as 用户
    participant Main as main()
    participant Gen as generate()
    participant TF as Transformer
    participant Sample as sample()

    User->>Main: 启动推理
    Main->>Main: 加载 ModelArgs 配置(如 inference/config_671B_v3.2.json)
    Main->>TF: Transformer(args)
    Main->>Main: load_model (safetensors)
    Main->>Main: 加载 tokenizer

    loop 交互循环
        User->>Main: 输入 prompt(多卡时仅 rank0 读取)
        Main->>Main: 多卡同步 prompt(dist.broadcast_object_list)
        Main->>Main: apply_chat_template
        Main->>Gen: generate(model, prompt_tokens, ...)

        rect rgb(200, 230, 255)
            Note over Gen,TF: Prefill 阶段
            Gen->>TF: forward(tokens[:, 0:len], start_pos=0)
            TF-->>Gen: logits
        end

        rect rgb(255, 230, 200)
            Note over Gen,TF: Decode 阶段
            loop until EOS
                Gen->>Sample: sample(logits, temp)
                Sample-->>Gen: next_token
                Gen->>TF: forward(next_token, start_pos=cur)
                TF-->>Gen: logits
            end
        end

        Gen-->>Main: completion_tokens
        Main-->>User: decoded text
    end

10.2 sample 函数

1def sample(logits, temperature: float = 1.0):
2    logits = logits / max(temperature, 1e-5)
3    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
4    # Gumbel-Max trick: 等价于 multinomial 但更高效
5    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)

11. 权重转换流程

位置: convert.py

11.1 转换流程图

flowchart TB
    subgraph "HuggingFace 格式"
        HF["163 个 safetensors 文件"]
    end

    subgraph "convert.py"
        HF --> Load["加载每个文件"]
        Load --> Skip["跳过 layer.61 (MTP)"]
        Skip --> Rename["重命名 keys"]
        Rename --> Shard["按 MP 分片"]
    end

    subgraph "输出"
        Shard --> Out["model{0-7}-mp8.safetensors"]
    end

11.2 Key 映射表

HuggingFace 名称 转换后名称 分片
embed_tokens embed dim=0
input_layernorm attn_norm None
post_attention_layernorm ffn_norm None
q_proj wq dim=0
q_a_proj wq_a None
q_a_layernorm q_norm None
q_b_proj wq_b dim=0
kv_a_proj_with_mqa wkv_a None
kv_a_layernorm kv_norm None
kv_b_proj wkv_b dim=0
o_proj wo dim=1
gate gate None
gate_proj w1 dim=0
up_proj w3 dim=0
down_proj w2 dim=1
norm norm None
lm_head head dim=0
wk wk None
k_norm k_norm None
weights_proj weights_proj None

额外重命名规则(实现中直接 str.replace):去掉前缀 model.self_attn → attnmlp → ffnweight_scale_inv → scalee_score_correction_bias → bias。另外会跳过 model.layers.61.*(MTP 层)。


附录 A: 张量维度速查表

组件 张量 维度 (DeepSeek-V3.2)
Embedding weight [129280/8, 7168]
MLA wq_a.weight [1536, 7168]
wq_b.weight [128×192/8, 1536]
wkv_a.weight [576, 7168]
wkv_b.weight [128×256/8, 512]
wo.weight [7168, 128×128/8]
kv_cache [8, 163840, 512]
pe_cache [8, 163840, 64]
Indexer wq_b.weight [64×128, 1536]
wk.weight [128, 7168]
k_cache [8, 163840, 128] FP8
MLP w1.weight [18432/8, 7168]
w2.weight [7168, 18432/8]
w3.weight [18432/8, 7168]
Expert w1.weight [2048, 7168]
w2.weight [7168, 2048]
w3.weight [2048, 7168]
Gate weight [256, 7168]
Head weight [129280/8, 7168]

附录 B: 关键公式汇总

MLA 注意力

$$ \begin{aligned} c_Q &= \text{RMSNorm}(W_{QA} \cdot x) & \text{压缩 Q} \ Q &= W_{QB} \cdot c_Q & \text{展开 Q} \ [c_{KV}, k_{pe}] &= W_{KVA} \cdot x & \text{压缩 KV} \ c_{KV} &= \text{RMSNorm}(c_{KV}) & \text{归一化} \ [k_{nope}, v] &= W_{KVB} \cdot c_{KV} & \text{展开 KV} \ \text{Attn} &= \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + M_{idx}\right) V \end{aligned} $$

MoE 路由

$$ \begin{aligned} s_i &= \sigma(W_{gate} \cdot x + b)i \ G_j &= \sum{i \in \text{group}_j} \text{top2}(s_i) \ \text{groups} &= \text{topk}(G, k=4) \ \text{experts} &= \text{topk}(s[\text{groups}], k=8) \ w &= \text{normalize}(s[\text{experts}]) \times 2.5 \end{aligned} $$

RMSNorm

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \cdot \gamma$$

SwiGLU

$$\text{SwiGLU}(x) = \text{SiLU}(W_1 x) \odot W_3 x$$


附录 C: 运行命令

转换权重

1cd inference
2python convert.py \
3    --hf-ckpt-path /path/to/hf/checkpoint \
4    --save-path /path/to/converted \
5    --n-experts 256 \
6    --model-parallel 8

单卡推理

1cd inference
2python generate.py \
3    --ckpt-path /path/to/converted \
4    --config config_671B_v3.2.json \
5    --interactive \
6    --temperature 0.6

多卡推理

1cd inference
2torchrun --nproc_per_node=8 generate.py \
3    --ckpt-path /path/to/converted \
4    --config config_671B_v3.2.json \
5    --interactive \
6    --temperature 0.6

附录 D: 参考资料