CPU+GPU,模型KV緩存壓力被緩解了。
來(lái)自CMU、華盛頓大學(xué)、Meta AI的研究人員提出MagicPIG,通過(guò)在CPU上使用LSH(局部敏感哈希)采樣技術(shù),有效克服了GPU內(nèi)存容量限制的問(wèn)題。
與僅使用GPU的注意力機(jī)制相比,MagicPIG在各種情況下提高了1.76~4.99倍的解碼吞吐量,并在檢索和推理任務(wù)中實(shí)現(xiàn)了更高的下游準(zhǔn)確率,優(yōu)于Quest等現(xiàn)有技術(shù)。
概括而言,這項(xiàng)研究主要貢獻(xiàn)有兩點(diǎn):
1、相比于其他的稀疏注意力(Sparse Attention),MagicPIG基于采樣/估計(jì)而非搜索,提升了推理質(zhì)量。
2、研究把解碼階段注意力模塊的計(jì)算和哈希表卸載到CPU上,探索了異構(gòu)計(jì)算的可能性,并且提升了吞吐量,有望降低實(shí)際模型部署成本。
下面具體來(lái)看。
KV緩存限制了GPU高效利用在長(zhǎng)上下文大模型(LLM)的推理過(guò)程中,KV緩存(Key-Value Cache)成為關(guān)鍵瓶頸。KV緩存主要用于存儲(chǔ)中間的注意力鍵和值,從而避免重復(fù)計(jì)算。
然而,其顯存占用隨著批量大小和序列長(zhǎng)度的線性增長(zhǎng)而迅速增加,這嚴(yán)重限制了GPU的批量處理能力,導(dǎo)致計(jì)算資源無(wú)法被充分利用。
以NVIDIA A100-40GB GPU為例,在處理Llama-3.1-8B模型且上下文長(zhǎng)度為128k時(shí),僅支持單個(gè)請(qǐng)求,且近一半的解碼時(shí)間都消耗在訪問(wèn)KV緩存上,GPU利用率明顯不足。
此外,推理過(guò)程中采用的一些策略,如多樣性生成(Best-of-N)和長(zhǎng)鏈?zhǔn)酵评恚↙ong Chain-of-Thoughts),會(huì)進(jìn)一步增加生成的Token數(shù)量,加劇顯存壓力,導(dǎo)致推理效率進(jìn)一步下降。
TopK Attention的問(wèn)題眾所周知,注意力機(jī)制本質(zhì)上具有稀疏性,因此動(dòng)態(tài)稀疏注意力和基于TopK的近似方法得到了廣泛研究。
然而,這些方法往往伴隨著顯著的質(zhì)量下降問(wèn)題。
目前已有的KV緩存壓縮技術(shù),如Quest、H2O和Loki,主要通過(guò)篩選出KV緩存中注意力得分最高的子集來(lái)提高效率。然而,盡管這些方法在實(shí)踐中表現(xiàn)出一定的效果,基于TopK的注意力依然是一種存在偏差的近似方法,且缺乏理論上的嚴(yán)格保障。
這種不足限制了其在高精度場(chǎng)景中的廣泛應(yīng)用。
下圖顯示,即使是精確的TopK注意力機(jī)制也會(huì)導(dǎo)致顯著的估計(jì)誤差和下游任務(wù)性能下降。
這一問(wèn)題在需要高上下文利用率的復(fù)雜任務(wù)中尤為突出,例如聚合任務(wù)、常用詞提取(CWE)、高頻詞提取(FWE)以及邏輯推理任務(wù)。在這些場(chǎng)景中,基于TopK近似方法的性能下降尤其嚴(yán)重。
以下幾點(diǎn)觀察揭示了為何TopK注意力機(jī)制無(wú)法始終有效工作。
這些觀察不僅解釋了注意力機(jī)制的行為,還可能對(duì)模型訓(xùn)練具有重要意義:
1、首個(gè)輸入token(注意力匯聚點(diǎn),sink)的隱藏狀態(tài)(包括但不限于鍵和值狀態(tài))幾乎不隨輸入變化而改變。(見左圖, 在采樣的輸入中,其最小相似度均高于0.99)
2、鍵狀態(tài)的中心方向在不同輸入句子中保持穩(wěn)定。(見中圖, 相似度均高于0.9)
3、鍵狀態(tài)的中心與匯聚點(diǎn)token的鍵狀態(tài)幾乎相反。(見右圖, -0.9至-0.8之間)
這些現(xiàn)象為理解注意力機(jī)制提供了新的視角,同時(shí)也表明傳統(tǒng)的TopK近似方法在某些場(chǎng)景下可能存在局限性。
為了解決這一問(wèn)題,研究提出了一種基于采樣而非搜索TopK鍵值緩存的新方法。
算法:基于采樣的注意力估計(jì)與僅依賴注意力分?jǐn)?shù)最高的鍵值對(duì)相比,融入基礎(chǔ)分布信息可以顯著提高估計(jì)的準(zhǔn)確性。
研究將這一問(wèn)題視為采樣中的偏差校正問(wèn)題。在生物學(xué)、社會(huì)學(xué)和機(jī)器學(xué)習(xí)等領(lǐng)域,無(wú)偏且高效的采樣技術(shù)已被廣泛研究,并具有堅(jiān)實(shí)的理論保障。
如圖所示,基于注意力分?jǐn)?shù)按比例進(jìn)行采樣(即所謂的Oracle Sampling,研究把注意力模塊的輸出看成value向量的期望值,對(duì)應(yīng)的分布是注意力得分)相比于傳統(tǒng)的TopK選擇方法,其估計(jì)誤差要小得多,最多可降低4倍。
這表明采樣技術(shù)在注意力近似中的潛力。
從注意力得分中采樣,在實(shí)際中不可行。重要性采樣(Importance Sampling)允許從一個(gè)已知分布中抽取樣本1,2,…,B,來(lái)估計(jì)未知分布的期望。
最終的輸出由下式給出:
重要性采樣要求和的峰值對(duì)應(yīng)以降低估計(jì)方差,為此,研究使用局部敏感哈希(LSH) 來(lái)生成采樣概率。
需要指出的是,因?yàn)榇嬖赟oftmax(注意力得分需要?dú)w一化), 所以研究實(shí)際上試圖近似的是自歸一化重要性采樣。
系統(tǒng):將注意力計(jì)算和哈希表放在CPU上除了精度下降的問(wèn)題外,受限的GPU顯存容量也限制了現(xiàn)有動(dòng)態(tài)KV緩存壓縮方法(如Quest和Loki)在許多場(chǎng)景中的適用性。
與此同時(shí),像DeepSpeed-Zero-Inference和FastDecode這樣的技術(shù)展示了將KV緩存和注意力計(jì)算卸載到CPU上的潛力。
CPU的內(nèi)存帶寬大約是GPU顯存帶寬的10%-20%,這引出了一個(gè)自然的問(wèn)題:
能否在不犧牲精度的前提下,將注意力計(jì)算中的內(nèi)存訪問(wèn)量減少10倍?
通過(guò)利用采樣算法,例如MagicPIG中基于LSH(局部敏感哈希)的采樣技術(shù)進(jìn)行注意力估計(jì),研究大幅降低了內(nèi)存訪問(wèn)量。這種方法等效地提升了CPU的內(nèi)存帶寬,使得在維持精度的情況下實(shí)現(xiàn)高效的注意力計(jì)算。
論文的系統(tǒng)設(shè)計(jì)擴(kuò)展了以往的工作,將大語(yǔ)言模型(LLM)的解碼分為以下四個(gè)部分:
參數(shù)計(jì)算:包括所有線性投均在GPU上運(yùn)行。
注意力計(jì)算:涉及公式,該部分在CPU上運(yùn)行。
隨機(jī)投影:在生成過(guò)程中,對(duì)于每個(gè)執(zhí)行K x L次隨機(jī)投影以生成哈希碼。由于所有注意力頭可以共享相同的隨機(jī)投影器,內(nèi)存開銷較。ㄔ趯(shí)際實(shí)現(xiàn)中約為400KB)。實(shí)驗(yàn)中K=9或10,而L為數(shù)百,因此該步驟主要受計(jì)算限制,放置在GPU上運(yùn)行。
檢索:需要在L個(gè)哈希表中查找q的哈希碼。這部分計(jì)算開銷非常輕量,但預(yù)構(gòu)建的哈希表占用的內(nèi)存較大,因此更適合放置在CPU上運(yùn)行。通過(guò)上述任務(wù)分區(qū),可以支持更大規(guī)模的K和L哈希表,而無(wú)需擔(dān)心哈希碼計(jì)算和哈希表存儲(chǔ)的開銷。
實(shí)驗(yàn)研究從準(zhǔn)確率和推理速度兩個(gè)方面來(lái)評(píng)估MagicPIG系統(tǒng)的能力。
圖片中的百分比為實(shí)際采樣的KV cache的數(shù)量,對(duì)于MagicPIG而言,K10L150≈2%, K10L170≈2.5%。
長(zhǎng)文本RULER以Llama-3.1-8B-Instruct為例,MagicPIG在檢索和推理任務(wù)中比Quest(稀疏注意力的SOTA基線)實(shí)現(xiàn)了更高的下游準(zhǔn)確率。
推理速度和吞吐量在L20 + Intel 8563C上測(cè)試吞吐量,MagicPIG與僅使用GPU的注意力機(jī)制相比,在各種情況下提高了1.76~4.99倍的解碼吞吐量。
整體而言,MagicPIG是將經(jīng)典的哈希算法和高維向量估計(jì)用到LLM解碼上的嘗試。
接下來(lái),研究將支持更加高效的局部敏感哈希算法,并希望進(jìn)一步降低LLM部署成本,探索異構(gòu)計(jì)算的可能性。