text-generation-inference 提高Santacoder和Starcoder(以及其他)的推理速度

4xy9mtcn  于 7个月前  发布在  其他
关注(0)|答案(7)|浏览(168)

我进行了一些广泛的调查、测试和基准测试,确定了以下几点可以加速Bigcode模型(以及大多数文本生成推理模型)的推理过程:

  1. 仅使用FlashAttention进行预填充。这是作者的建议,因为FlashAttention内核依赖于较高的查询长度来实现良好的并行化,而且FlashAttention需要对每个令牌的输入/输出/KV缓存做很多额外的工作。

  2. 尽可能地向量化预处理/后处理操作,即避免循环(尤其是针对cuda操作)。Warpers / logit处理器已经在feat(server): support vectorized warpers in flash causal lm #317中进行了向量化,而causal_lm中的其余部分在[Prototype] Vectorized causal lm #272中有一个原型实现(flash_causal_lm较难向量化,但根据上述观点,causal_lm应该更可取)。

  3. 对KV缓存进行某种形式的预分配,并将键长度填充为8的倍数。完全静态预分配的Tensor会带来复杂性,因为需要连接/过滤批次,但提前预先分配几个令牌以便在每N个令牌上运行缓慢的连接而不是全部连接是相对容易的。(再次强调,这在FlashAttention中是无法实现的。)将键长度填充为8的倍数也提供了很高的加速比,因此N=8是一个最低要求(尽管更高更好)。

  4. 仅在请求时计算details(logprobs、预填充数据等)(Make generation details optional #288)。这些操作耗时较长,迫使计算整个模型头(见下文),但结果几乎总是被丢弃。

  5. 仅在预填充的最后一个令牌上计算模型头(除非我们需要它们用于details)。这样可以节省一些时间,更重要的是避免内存瓶颈。

  6. 仅在提供种子时使用确定性生成。否则,需要在循环中进行采样,因为Pytorch不支持向量化生成器。

  7. 修剪Python代码。避免任何不必要的函数调用(尽可能使用内联),属性获取等,因为这些最终会导致CPU延迟增加。避免继承nn.Module,因为这会在__call__getattr上添加很多冗余(钩子)。在测试中,我通过这种方式成功地将Santacoder的最小延迟降低了20%以上。
    未来的工作(需要更多调查):

  8. 尝试比较更多的融合内核。对于融合softmax,可以与Jit(在[Prototype] Vectorized causal lm #272中使用)和Megatron的实现(可能更好)进行比较。比较融合和标准层归一化(下面的结果似乎与融合相矛盾)。尝试在MLP中融合密集(带有gelu)的内核(或者尝试Jit?)

  9. 通过预分配和/或重用Tensor减少内存分配。主要障碍是许多操作仍然不支持out参数,因此需要一些简单的cpp工作。

  10. 将CPU密集部分( Block )写入cpp。这不会太困难,对于较小的模型会有很大帮助提高延迟,但如果使用cuda图,可能不需要这样做。

  11. 为cuda图添加支持,至少用于解码。我已经展示了它们如何与动态形状一起工作(使用大量的图),并且它们为Santacoder带来了很大的加速比(Starcoder也带来了一点加速比),但由于静态KV缓存位置,它们给批处理 concatenate / filter 带来了复杂性。一个选项是始终使用相同的批量大小进行解码(或者一些预先确定的值,例如2的幂次方),这样每次都需要昂贵地在每个 filter 上 Shuffle 数据是可以接受的,因为(Santacoder)解码延迟基本上不受批量大小的影响。

  12. 深入研究Tensor并行性。我知道它已经在文本生成推理中实现了,但我自己还没有深入研究过。

svujldwt

svujldwt1#

一些基准测试结果,比较了几个实现:

  1. flash : flash_santacoder ,当前的实现。
  2. causal : HF变压器模型的 gpt_bigcode ,运行于 causal_lm
  3. vector : HF变压器模型的 gpt_bigcode ,运行于 vectorized_causal_lm ,来自 [Prototype] Vectorized causal lm #272(上文提到的2)。
  4. bigcode :Bigcode变压器仓库中的 gpt_bigcode 模型,经过少量调整和裁剪以适应文本生成推理和 vectorized_causal_lm (上文提到的1、3、4、5、6)
  5. bigcode2 : bigcode ,采用了从 flash_santacoder 中获得的一些额外优化,主要是 FastLinearFastLayerNorm 层。同时对注意力掩码进行了一些简化。
  6. bigcode3 : bigcode2 ,经过裁剪的Python代码(上文提到的7)
    注意: flashcausal 是基于提交 5a58226 (5月16日)的,所以可能缺少最新的优化。
    另外请注意:曲线已经平滑处理,否则它们会在没有关键长度填充的情况下剧烈波动( causalvector )

Santacoder解码

  • 对于批量大小=1,CPU始终是瓶颈。 flash 是最快的,并且 bigcode1/2/3 之间有很大的差距。Megatron的融合softmax可能会使 bigcode3flash 几乎相当(我仍然认为 flash 会更快,因为它有更少的内核)
  • flashcausal 在高批量大小下表现非常差,尤其是对于长序列。这是由于非矢量化操作和FlashAttention性能不佳所致。
  • vector 已经将批量大小开销降至最低。
  • bigcode1/2/3 显示了其他改进。
  • 令人惊讶的是,对于较大的序列, x1m33n3x 时, x34n3x 甚至比 x35n3x 更慢。这可能是由于子优化的融合层归一化?

x36d3x
x37d3x
x38d3x

Santacoder预填充

  • x39d3x n40d4x 的表现非常差(没有FlashAttention)
  • x39d5x n40d6x 也不太好,似乎与 x39d7x 中的大量处理有关。
  • x39d8x n40d9x 的表现最好,它们非常相似(除了当CPU受限时的 bs=
drnojrws

drnojrws2#

Starcoder解码

  • 类似于Santacoder,但flash在批量大小为1时已经效率低下,甚至比causal更差。
  • 对于小批量大小的延迟,瓶颈在于读取权重(15.5e9参数 * 2B/param / 2039e9B/s = 15.2毫秒),因此Tensor并行化可能会降低它。
  • causal对于大序列表现不佳,不确定原因。
  • 再次,bigcode2/3bigcode更差,怀疑是融合的层归一化。
  • 对于批量大小为256,小序列的时间比较小批量大小的时间更高,表明读取权重不再是瓶颈。

Starcoder预填充

  • 类似于Santacoder。
  • bigcode2/3略快于bigcode,但运行速度更快,内存耗尽更快。

2wnc66cl

2wnc66cl3#

感谢你的精彩调查。
"""为SantaCoder添加对cuda图的支持,至少支持解码。我已经向他们展示了如何使用动态形状(使用大量图)工作,它们为SantaCoder(以及Starcoder)带来了很大的加速,但由于静态KV缓存位置,它们在批处理连接/过滤时增加了复杂性。一个选项是始终使用相同的批处理大小(或一些预先确定的值,例如2的幂)进行解码,以避免每次过滤器都对数据进行昂贵的 Shuffle ,这应该没问题,因为(SantaCoder)解码延迟主要与批处理大小无关。"""
你能告诉我你在哪里为SantaCoder实现了具有动态大小的cuda图吗?我想知道它是如何实现的。

axkjgtzd

axkjgtzd4#

感谢jlamypoirier的精彩调查。

"""为SantaCoder添加对cuda图的支持,至少支持解码。我已经向他们展示了如何与动态形状(使用大量图)一起工作,它们为SantaCoder带来了很大的加速(对于Starcoder也有一定的加速),但由于静态KV缓存位置,它们在批处理连接/过滤方面增加了复杂性。一个选项是始终使用相同的批处理大小(或一些预先确定的值,例如2的幂)进行解码,这样可以避免每次过滤器都对数据进行昂贵的 Shuffle 操作,因为(SantaCoder)解码延迟主要与批处理大小无关。"""

你能告诉我你是如何为SantaCoder实现具有动态大小的cuda图的吗?我很好奇它是如何实现的。
抱歉回复晚了,你可以在我(混乱的)实现中找到它 . 注意,这个版本支持动态键长度,但不支持动态批处理大小。

ojsjcaue

ojsjcaue5#

非常棒的报告!我想问一下,序列长度是否表示$max_new_token$?当我将$max_new_token$设置为128时,我在starcoder上得到了相当高的延迟(大约4秒)。

mrfwxfqh

mrfwxfqh6#

@jlamypoirier 令人惊叹的报告!请问序列长度是否表示max_new_token?当我将max_new_token设置为128时,我在starcoder上获得了相当高的延迟(大约4秒)。生成一个令牌所需的时间。为了获得完整的时间,您需要添加预填充以获取上下文长度,并在范围(context_length, context_length + max_new_tokens)内生成。

0yycz8jy

0yycz8jy7#

这些都是很好的建议。有没有人发现这些方法已经应用到上游了?
如果没有,你的版本在哪里可以找到?
编辑:特别好奇关于
仅计算预填充中最后一个标记的模型头(除非我们需要它们来获取详细信息)。这可以节省一些时间,更重要的是避免内存瓶颈。

相关问题