我们的贡献
为解决标准 Newton-Schulz 算法的缺点,引入 Gram Newton-Schulz 算法,在万亿参数的 MoE 模型中,可将优化器时间最多减少 50%。该算法在小的方形对称 Gram 矩阵上迭代,降低浮点运算成本,更多使用对称 GEMM 内核。贡献包括将标准算法重写为朴素 Gram Newton-Schulz 算法、研究其数值特性并改进为稳定算法、实现自定义对称矩阵乘法内核、用其取代 Muon 的 Newton-Schulz 例程得到 GramMuon 优化器,且发布开源实现。
Muon 回顾
Muon 是训练先进语言模型的首选优化器,与 AdamW 相比,达到给定损失所需优化器步骤更少,但每个步骤计算成本更高,开销源于 Newton-Schulz 正交化过程。Muon 更新规则涉及动量矩阵和极分解操作,使用 Newton-Schulz 方法近似极分解。标准 Newton-Schulz 算法实现有特定步骤,后续工作试图改进 Muon,但大多未改变 Newton-Schulz 例程的实际运行时间。标准 Newton-Schulz 算法运行时间分析显示其存在对称矩阵乘法和对矩阵纵横比依赖的缺点。
Gram Newton-Schulz 算法
该算法通过在小的方形对称 Gram 矩阵上迭代,减少昂贵的矩形矩阵乘法数量,输出与标准 Newton-Schulz 算法相同但计算成本显著降低。其基于特定公式,将迭代多项式方法转换为近似平方根倒数的方法,有朴素 Gram Newton-Schulz 算法版本。该算法与 Polar Express 论文附录 F 方法相关,但在公式、内核使用和稳定性分析方面有超越。朴素 Gram Newton-Schulz 算法运行时间计算表明,在典型 Muon 应用中,比使用对称 GEMM 运算的标准算法节省 55% 浮点运算次数,比不使用对称 GEMM 运算的典型实现节省 68%。
朴素 Gram Newton-Schulz 算法的不稳定性
使用朴素 Gram Newton-Schulz 算法训练 Transformer 大语言模型会出现损失值飙升和输出充满无穷大的问题。通过跟踪中间矩阵的特征值,发现不稳定性源于 Gram 矩阵存在虚假负特征值和特征向量漂移。