03 RoPE
2026/6/20 12:45:33 网站建设 项目流程

03 RoPE


1. 为什么需要 RoPE?

1.1 位置编码要解决什么问题?

Transformer 的 Attention 机制本身是置换不变的——把输入序列打乱,Attention 的输出(在没有位置信息的情况下)也会被打乱。模型不知道"第一个 token"和"第五个 token"有什么区别。

所以需要在输入中注入位置信息

1.2 绝对位置编码的痛点

早期方案(如原始 Transformer 的正弦波位置编码、GPT-2 的可学习位置编码)都是绝对位置编码

input=token_embedding+position_embedding\text{input} = \text{token\_embedding} + \text{position\_embedding}input=token_embedding+position_embedding

致命问题:模型在训练时只见过位置 0~4095(4K 上下文),推理时给一个位置 5000 的 token,位置编码是"没见过"的——模型直接懵了。

这就是上下文长度外推(Context Extension)问题的根源。

1.3 RoPE 的核心洞察

不直接告诉模型"这是第几个 token",而是让 Attention 计算自然包含 token 之间的相对距离

具体做法:对 Query 和 Key 向量施加一个旋转,旋转角度正比于 token 的位置编号。当计算内积⟨qm,kn⟩\langle q_m, k_n \rangleqm,kn时,结果只依赖于相对位置(m−n)(m - n)(mn)

q_m · k_n = f(内容相似度, m - n) ← 只和相对距离有关! ↑ 不是绝对位置 m 或 n

这带来了一个重要能力:训练时用 4K 序列,推理时可以外推到 16K 甚至 128K(配合后续的 RoPE Scaling 技术)。


2. 数学原理:借用复数的旋转

2.1 二维旋转的数学

把一个二维向量(x1,x2)(x_1, x_2)(x1,x2)旋转角度θ\thetaθ

[x1′x2′]=[cos⁡θ−sin⁡θsin⁡θcos⁡θ][x1x2]\begin{bmatrix} x_1' \\ x_2' \end{bmatrix} = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}[x1x2]=[cosθsinθsinθcosθ][x1x2]

用复数可以写得更优雅:z=x1+ix2z = x_1 + i x_2z=x1+ix2,旋转θ\thetaθ就是z′=z⋅eiθz' = z \cdot e^{i\theta}z=zeiθ

其中eiθ=cos⁡θ+isin⁡θe^{i\theta} = \cos\theta + i\sin\thetaeiθ=cosθ+isinθ(欧拉公式)。

2.2 推广到高维:分组旋转

实际的 Query/Key 向量是ddd维的(比如d=128d=128d=128)。RoPE 的做法是:

  1. ddd维向量两两配对,分成d/2d/2d/2
  2. 每组(2 维)做一个独立的旋转
  3. iii组的旋转角度是θi=10000−2i/d\theta_i = 10000^{-2i/d}θi=100002i/d
  4. 对于位置mmm的 token,第iii组旋转m⋅θim \cdot \theta_imθi

为什么不同维度组旋转速度不同?θi\theta_iθiiii指数衰减——低维度组旋转快(高频),高维度组旋转慢(低频)。这模拟了"短距离信息靠高频捕捉,长距离信息靠低频捕捉"的信号处理直觉。

2.3 频率预计算

# 频率公式:θ_i = 10000^{-2i/d},i = 0, 1, ..., d/2-1freqs=1.0/(10000**(torch.arange(0,dim,2)[:dim//2].float()/dim))# 结果:freqs = [1/1.00, 1/1.15, 1/1.33, ..., 1/10000]# 高频 ←――――――――――――――――→ 低频

对于每个位置mmm(0 到 seq_len-1),该位置的旋转角度矩阵是m * freqs

用极坐标生成复数

# torch.polar(abs, angle) 生成 abs * e^{i * angle}freqs_cis=torch.polar(torch.ones_like(angles),angles)# shape: [seq_len, dim//2]

freqs_cis[m, i]就是位置 m 在第 i 组维度的旋转复数ei⋅m⋅θie^{i \cdot m \cdot \theta_i}eimθi

2.4 应用旋转:复数乘法

把 Query 的最后一维(ddd)reshape 成[d/2,2][d/2, 2][d/2,2],然后用torch.view_as_complex解释为d/2d/2d/2个复数:

# xq shape: [B, L, num_heads, d]# Step 1: reshape → [B, L, num_heads, d/2, 2]# Step 2: view_as_complex → [B, L, num_heads, d/2] (复数张量)xq_complex=torch.view_as_complex(xq.reshape(*xq.shape[:-1],-1,2))

然后广播复数乘法:

xq_rotated_complex=xq_complex*freqs_cis# 复数旋转!# 再转回实数:xq_rotated=torch.view_as_real(xq_rotated_complex).flatten(3)

复数乘法自动实现了旋转矩阵
(a+bi)×(cos⁡θ+isin⁡θ)=(acos⁡θ−bsin⁡θ)+i(asin⁡θ+bcos⁡θ)(a+bi) \times (\cos\theta + i\sin\theta) = (a\cos\theta - b\sin\theta) + i(a\sin\theta + b\cos\theta)(a+bi)×(cosθ+isinθ)=(acosθbsinθ)+i(asinθ+bcosθ)
这正好等于[cos⁡θ−sin⁡θsin⁡θcos⁡θ][ab]\begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix} \begin{bmatrix} a \\ b \end{bmatrix}[cosθsinθsinθcosθ][ab]


3. 代码实现

3.1 预计算频率表

defprecompute_freqs_cis(dim:int,end:int,theta:float=10000.0):""" 计算复数指数频率张量。 返回 shape: [end, dim//2],dtype=complex64 freqs_cis[m, i] = e^{i * m * θ_i} """# Step 1: 计算每个维度组的频率 θ_i = 10000^{-2i/d}freqs=1.0/(theta**(torch.arange(0,dim,2)[:(dim//2)].float()/dim))# Step 2: 每个位置 m 对应的角度 m * θ_it=torch.arange(end,dtype=torch.float32)# [end]angles=torch.outer(t,freqs)# [end, dim//2]# Step 3: 用极坐标生成复数 e^{i * angle}freqs_cis=torch.polar(torch.ones_like(angles),angles)returnfreqs_cis

torch.outer的作用:把t[0..seq_len]freqs[0..d/2]做外积,得到[seq_len, d/2]的角度矩阵。每个元素angles[m, i] = m * θ_i

3.2 应用旋转编码 —— 关键:必须 FP32 Upcast

defapply_rotary_emb(xq,xk,freqs_cis):""" 对 Query 和 Key 施加 RoPE 旋转。 xq, xk: [B, L, num_heads, head_dim] freqs_cis: [seq_len, head_dim//2] """# Step 1: 转为复数 — ⚠️ 先升精度到 FP32!# 复数乘法在 FP16 下极易产生 NaN,LLaMA 源码强制 FP32xq_=torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2))xk_=torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2))# Step 2: 调整 freqs_cis 形状以广播# freqs_cis: [L, d/2] → [1, L, 1, d/2]freqs_cis=reshape_for_broadcast(freqs_cis,xq_)# Step 3: 复数乘法(旋转)并转回实数xq_out=torch.view_as_real(xq_*freqs_cis).flatten(3)xk_out=torch.view_as_real(xk_*freqs_cis).flatten(3)# Step 4: 转回输入精度returnxq_out.type_as(xq),xk_out.type_as(xk)

3.3 为什么必须 Upcast 到 FP32?

这是 RoPE 实现中最容易踩的坑。

复数乘法(a+bi)×(c+di)(a+bi) \times (c+di)(a+bi)×(c+di)内部做了 4 次浮点乘法和 2 次加减。在 FP16 下:

  • 精度只有 ~3.3 位有效十进制数字
  • 多次乘法后误差快速累积
  • 复数运算的误差传播更复杂(实部虚部交叉影响),极易产生 NaN

LLaMA 官方源码强制在 FP32 下做 RoPE 旋转,算完再转回 FP16/BF16。如果你漏了.float(),模型在训练几万步后精度会悄悄退化。


4. 工业实现对照

4.1 LLaMA 源码的关键差异

维度本教程(复数法)LLaMA 官方(实数法)
实现方式view_as_complex+ 复数乘法手动 split + cos/sin + 交叉乘加
可读性⭐⭐⭐⭐⭐ 数学直觉清晰⭐⭐ 代码冗长
性能依赖 PyTorch 复数优化编译器更容易优化
精度必须FP32也需要 FP32

LLaMA 官方用实数法是因为编译器的复数支持在旧版本不够好,但原理完全等价。

4.2 上下文外推(Context Extension)—— 训练 4K,推理 128K

模型在 4K 序列训练,推理时如何支持 16K?这就是 RoPE Scaling 要解决的问题:

方法做法代表
线性插值位置索引除以缩放因子:m → m/sLLaMA 2 (32K)
NTK-aware调大基频:θ = 10000 → 100000,让高频减速Qwen (128K)
YaRN高频插值 + 低频外推,按维度组分别处理学术方案

核心思想都是压缩位置空间降低旋转速度,让训练时见过的旋转角度能覆盖推理时的更长上下文。


5. 踩坑记录

5.1 忘记 Upcast 到 FP32

  • 现象:训练几万步后 loss 逐渐发散或精度退化,FP16 下更严重
  • 根因:复数乘法在 FP16 下误差累积。(a+bi)*(c+di)的每一步乘加都在损失精度
  • 解决:在view_as_complex前加.float(),return 前.type_as(xq)转回

5.2 旋转角度生成顺序写反

  • 现象:模型完全不收敛,perplexity 巨高
  • 根因torch.outer(t, freqs)torch.outer(freqs, t)生成的角度矩阵形状不同([seq_len, d/2]vs[d/2, seq_len]),导致广播到错误的维度
  • 解决torch.outer(t, freqs)— t 是位置(行),freqs 是频率(列)

5.3 只对 Query 旋转没对 Key 旋转

  • 现象:位置信息似乎没生效,长序列效果与非位置模型差不多
  • 根因:RoPE 必须同时旋转 Q 和 K,内积⟨q,k⟩\langle q, k \rangleq,k的交叉项才会自然出现相对位置m−nm-nmn
  • 解决apply_rotary_emb必须同时处理 xq 和 xk

5.4flatten(3)的参数记错

  • 现象:输出形状变成[B, L, H, d/2, 2]而不是[B, L, H, d]
  • 根因view_as_real在最后增加了一个维度 2(实部/虚部),需要flatten(3)合并回[d]
  • 解决flatten(3)表示从第 3 维(0-indexed)开始压平

6. 延伸思考

  • 为什么不对 Value 也做 RoPE?实验发现对 V 做旋转没有额外收益。因为 Attention 的"位置感知"只需要⟨qm,kn⟩\langle q_m, k_n \rangleqm,kn体现相对位置——输出是 V 的加权和,V 本身不需要知道位置
  • RoPE 和 ALiBi 的区别:ALiBi 直接在 Attention score 上加一个相对位置 bias(不修改 Q/K 本身),更简单但表达力弱于 RoPE
  • Triton 融合实现:在 Triton 中可以把整个 RoPE kernel 写成一个 fused kernel,消除中间张量的显存读写
  • 与后续内容的关系:RoPE 是 Attention 实现(04_Attention_MHA_GQA)的前置知识。理解 RoPE 后,MHA/GQA 中 Q/K 的初始化和前向传播就顺理成章了

RoPE 的优雅之处在于:用复数旋转这个 200 年前的数学工具,解决了大模型"位置泛化"这个 2023 年的工程难题。复数不是装饰品,是实实在在的工程抓手。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询