传统的transformer在处理长数字序列时,难以准确地跟踪和表示每个数字的位置,导致在进行多步骤和复杂运算时性能不理想。 这篇论文解决了transformer在处理算术任务(如多位数加法、乘法和排序)时表现不佳的问题。 <strong>解决的问题:</strong> 通过引入一种新的位置嵌入方法(Abacus Embeddings),研究人员显著提高了transformer处理复杂算术任务的能力。使其在训练仅一天后,能够准确处理多达100位数的加法问题,并且这种改进也能推广到乘法和排序等其他多步骤推理任务。 <ol> <li><strong>位置记忆问题</strong>:transformer模型在处理一长串数字时,无法很好地记住每个数字的位置,导致在处理多步运算(比如多位数加法、乘法)时出错。</li> <li><strong>算术运算能力不足</strong>:由于位置记忆的问题,传统transformer在处理复杂的算术任务时,准确率很低,特别是当数字长度超出训练数据的范围时。</li> </ol> <strong>解决方案:</strong> <ol> <li><strong>引入新的位置嵌入方法</strong>:论文提出了一种叫做“Abacus嵌入”的新方法,帮助模型更好地记住每个数字的位置。这种方法在每个数字上添加一个表示其在序列中相对位置的编码。</li> <li><strong>架构改进</strong>: <ul> <li><strong>输入注入</strong>:在每层计算中都引入原始输入信息,帮助模型在每一步计算中都能参考原始数字的位置。</li> <li><strong>循环transformer</strong>:通过让模型重复使用某些计算层次,使模型能够更好地进行多步运算。</li> </ul> </li> </ol> 通过这些改进,论文展示了transformer模型在处理多位数加法和乘法时的显著提升。使用这些新方法的模型在加法任务中达到了99%的准确率,甚至能够处理长度比训练数据中最长数字还要长的数字。同时,这些改进也在其他需要多步推理的任务(如排序)中表现良好。 <h3><img class="aligncenter size-full wp-image-8892" src="https://img.xiaohu.ai/2024/06/Jietu20240601-204356@2x.jpg" alt="" width="1416" height="702" />详细方法介绍</h3> <h4>1. Abacus嵌入(Abacus Embeddings)</h4> Abacus嵌入是一种新的位置嵌入方法,用于帮助transformer模型更好地表示和记住数字在序列中的相对位置。 <ul> <li><strong>方法概述</strong>: <ul> <li>为每个数字添加一个嵌入,编码其相对于数字开头的位置。</li> <li>使用这种嵌入,模型可以更容易地理解每个数字在长序列中的位置,从而在处理多步算术运算时表现更好。</li> </ul> </li> <li><strong>具体实现</strong>: <ul> <li>对于每个输入数字序列,生成一组位置嵌入,例如,数字123的嵌入可能是β, β+1, β+2,其中β是一个随机偏移量。</li> <li>训练时,偏移量β在一个固定范围内随机选择,确保模型能够看到各种位置嵌入。</li> </ul> </li> </ul> <h4>2. 输入注入(Input Injection)</h4> 输入注入是一种在每层计算中都引入原始输入信息的方法,帮助模型在每一步计算中都能参考原始数字的位置。 <ul> <li><strong>方法概述</strong>: <ul> <li>在transformer的每个解码层之间插入输入注入,将输入特征添加到每一层的隐藏表示中。</li> <li>这种方法确保模型在每一层都能访问原始输入数据,减少信息在多层传递中的丢失。</li> </ul> </li> <li><strong>具体实现</strong>: <ul> <li>将嵌入后的输入特征直接添加到每一层的输入中,形成一种跳跃连接(skip connection)。</li> </ul> </li> </ul> <h4>3. 循环transformer架构(Looped Transformer Architecture)</h4> 循环transformer是一种包含循环层的transformer架构,通过重复使用同一参数集来增加模型的有效深度。 <ul> <li><strong>方法概述</strong>: <ul> <li>使用循环块(recurrent block),即一组具有独特权重的解码层,通过多次重复这些块来实现更深的模型。</li> <li>循环transformer有助于改进多步推理过程中的推理能力。</li> </ul> </li> <li><strong>具体实现</strong>: <ul> <li>模型的每个循环块包含若干解码层,循环次数可以调整,例如,一个循环块包含8层解码器,循环2次,则模型的有效深度为16层。</li> <li>在每次前向传播中,输入注入确保每一层都能访问原始输入数据。</li> </ul> </li> </ul> <h4>4. 进阶损失计算(Progressive Loss Computation)</h4> 进阶损失计算方法通过在训练过程中使用渐进的损失计算,帮助模型在测试时对更难任务的推广能力。 <ul> <li><strong>方法概述</strong>: <ul> <li>在一次前向传播中,使用标准的循环次数和随机较少的循环次数,计算两者的损失值的凸组合。</li> <li>这种方法提高了模型在测试时对更复杂任务的推广能力。</li> </ul> </li> <li><strong>具体实现</strong>: <ul> <li>在每次训练中,计算正常循环次数的损失值和较少循环次数的损失值,并计算两者的加权和作为最终损失值。</li> </ul> </li> </ul> <h3><img class="aligncenter size-full wp-image-8891" src="https://img.xiaohu.ai/2024/06/Jietu20240601-204418@2x.jpg" alt="" width="1570" height="322" /></h3> <h3>实验与结果详细介绍</h3> <h4>实验设置</h4> <ol> <li><strong>数据集构建</strong>: <ul> <li><strong>加法任务</strong>:生成长度最多为i和j的操作数组合,形成包含2000万样本的训练集,i = j。训练集中每对操作数长度的组合被均匀采样,确保所有长度对在训练过程中均匀出现。</li> <li><strong>乘法任务</strong>:类似于加法任务,但操作变为乘法。</li> <li><strong>排序任务</strong>:输入为多个反向数字,输出为字符索引按升序排列的结果。</li> <li><strong>评估方式</strong>:测试分为三类:训练分布内(ID),训练分布外(OOD),极端训练分布外(100+ digit OOD)。</li> </ul> </li> <li><strong>模型架构</strong>: <ul> <li><strong>标准transformer</strong>:多层堆叠的解码器层。</li> <li><strong>标准transformer + 输入注入</strong>:在每层之间插入输入注入。</li> <li><strong>循环transformer</strong>:包含循环层的transformer架构,具有输入注入和渐进损失计算。</li> </ul> </li> <li><strong>训练细节</strong>: <ul> <li>使用Nvidia RTX A4000 GPU训练24小时,限制计算量为8 exaFLOP。</li> <li>损失计算仅在答案部分进行。</li> </ul> </li> </ol> <h4>实验结果</h4> <ol> <li><strong>加法任务</strong>: <ul> <li><strong>Abacus嵌入提升表现</strong>: <ul> <li>在处理长达100位数的加法问题时,使用Abacus嵌入的模型可以达到99%的准确率。</li> <li>与FIRE和NoPE嵌入相比,Abacus嵌入在训练分布外(OOD)表现更好。</li> <li>使用Abacus嵌入结合输入注入和循环transformer的架构改进,在训练数据最大操作数长度为20位的情况下,实现了6倍的长度推广。</li> </ul> </li> <li><strong>不同架构的比较</strong>: <ul> <li>标准transformer和标准transformer + 输入注入的模型准确率显著低于循环transformer模型。</li> <li>在最优架构(循环transformer,8层循环块,两次循环)下,Abacus嵌入的模型在OOD问题上的准确率几乎达到100%。</li> </ul> </li> </ul> </li> <li><strong>乘法任务</strong>: <ul> <li><strong>Abacus嵌入的优势</strong>: <ul> <li>使用Abacus嵌入的循环transformer模型在处理长达15位数的乘法问题时,达到了近乎完美的分布内准确率。</li> <li>在分布内的最难问题上,Abacus嵌入明显优于FIRE嵌入。</li> </ul> </li> </ul> </li> <li><strong>排序任务</strong>: <ul> <li><strong>嵌入的影响</strong>: <ul> <li>在标准transformer架构下,Abacus嵌入提升了排序任务的OOD表现。</li> <li>使用Abacus嵌入和FIRE嵌入结合的模型,在处理长度为30位的数组时,表现优于单独使用任一种嵌入。</li> </ul> </li> </ul> </li> <li><strong>架构变体</strong>: <ul> <li><strong>渐进损失计算的效果</strong>: <ul> <li>循环transformer模型通过渐进损失计算,在较小的循环次数下也能保持高准确率,进一步提升了OOD表现。</li> </ul> </li> </ul> </li> <li><strong>综合性能</strong>: <ul> <li><strong>总体表现</strong>: <ul> <li>Abacus嵌入结合循环transformer架构,显著提升了transformer在算术任务中的表现。</li> <li>提出了适用于更复杂多步骤推理任务的通用方法。</li> </ul> </li> </ul> </li> </ol> <strong>可视化结果</strong> <ol> <li><strong>加法任务</strong>:准确率随操作数长度变化的热图,展示了Abacus嵌入在训练分布内和分布外的出色表现。</li> <li><strong>乘法任务</strong>:循环transformer模型在处理乘法任务时的精确匹配准确率图,显示了Abacus嵌入在最难问题上的优势。</li> <li><strong>排序任务</strong>:不同嵌入和架构在排序任务中的表现对比,Abacus嵌入结合FIRE嵌入在各种场景下均表现出色。</li> </ol> <h3><img class="aligncenter size-full wp-image-8890" src="https://img.xiaohu.ai/2024/06/Jietu20240601-204430@2x.jpg" alt="" width="1618" height="548" /></h3> <h3>结论</h3> <ol> <li><strong>改进位置嵌入的有效性</strong>: <ul> <li><strong>Abacus嵌入</strong>显著提高了模型在处理多位数加法和乘法问题时的准确性。</li> <li>这种嵌入方法能够更好地捕捉每个数字在序列中的相对位置,从而提高模型在长数字序列中的表现。</li> </ul> </li> <li><strong>架构改进的效果</strong>: <ul> <li><strong>输入注入</strong>和<strong>循环transformer架构</strong>进一步增强了模型的多步骤推理能力。</li> <li>输入注入帮助模型在每一层计算中保持对原始输入数据的参考,减少信息丢失。</li> <li>循环transformer架构通过重复使用解码层,提高了模型的有效深度和复杂任务的处理能力。</li> </ul> </li> <li><strong>实验验证</strong>: <ul> <li>实验结果显示,使用Abacus嵌入结合循环transformer架构的模型,在处理长达100位数的加法问题时,准确率达到99%。</li> <li>在乘法和排序等任务中,这些改进方法同样表现出色,证明了其广泛的适用性和强大的外推能力。</li> </ul> </li> <li><strong>推广能力</strong>: <ul> <li>这些改进方法不仅提高了模型在训练数据范围内的表现,还显著提升了模型在训练数据范围之外(OOD)问题上的准确性。</li> <li>模型能够处理比训练数据更复杂和更长的数字序列,展示了其强大的推广能力和鲁棒性。</li> </ul> </li> <li><strong>未来应用的潜力</strong>: <ul> <li>论文提出的方法为解决算术任务提供了新的思路,也为进一步研究transformer在自然语言处理和算法推理任务中的应用提供了重要参考。</li> <li>未来可以探索这些嵌入和架构改进在更广泛的任务中的应用,进一步提升大型语言模型的算法推理能力。</li> </ul> </li> </ol> GitHub:<a href="https://github.com/mcleish7/arithmetic" target="_blank" rel="noopener">https://github.com/mcleish7/arithmetic</a> 论文:<a href="https://arxiv.org/abs/2405.17399" target="_blank" rel="noopener">https://arxiv.org/abs/2405.17399</a>