
HSTU Attention Development and Optimization with Cutlass/CuTe

Information Technology · 2025-05-30 · NVIDIA · 乐***

Agenda
• HSTU block Introduction
• Attention optimization
• Rab & dRab
• CuTe Practical
• QA

HSTU block Introduction

$$U(X),\ V(X),\ Q(X),\ K(X) = \mathrm{Split}\big(\phi_1(f_1(X))\big)$$
$$A(X)V(X) = \phi_2\big(Q(X)K(X)^{T} + \mathrm{rab}^{p,t}\big)\,V(X)$$
$$Y(X) = f_2\big(\mathrm{Norm}(A(X)V(X)) \odot U(X)\big)$$

In the forward pass, the computations are fused into a few kernels, with the result of $QK^T$ kept in registers as an intermediate. In the backward pass, the computations likewise need to be fused into a few kernels; since the gradient computation depends on $QK^T$, that result is recomputed and again kept in registers as an intermediate.

For the backward of the Norm (layer normalization), let $\mu$ and $\psi$ denote the mean and the reciprocal standard deviation over the $N$ normalized elements, so that the normalized value can be written either as $\psi(x_i-\mu)$ or, equivalently, as $(y_i-\beta_i)/\gamma_i$. There are four equivalent computational formulas for the input gradient:

$$dx_i = \psi\Big(dy_i\gamma_i - \tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k - \psi^2 (x_i-\mu)\cdot\tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k (x_k-\mu)\Big)$$
$$dx_i = \psi\Big(dy_i\gamma_i - \tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k - \psi\,\tfrac{y_i-\beta_i}{\gamma_i}\cdot\tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k (x_k-\mu)\Big)$$
$$dx_i = \psi\Big(dy_i\gamma_i - \tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k - \psi (x_i-\mu)\cdot\tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k \tfrac{y_k-\beta_k}{\gamma_k}\Big)$$
$$dx_i = \psi\Big(dy_i\gamma_i - \tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k - \tfrac{y_i-\beta_i}{\gamma_i}\cdot\tfrac{1}{N}\textstyle\sum_k dy_k\gamma_k \tfrac{y_k-\beta_k}{\gamma_k}\Big)$$

Attention optimization
• Adjust the computation flow (order of instructions) and the synchronization timing so that computation and memory access hide each other's latency.
• Unroll computation or memory-access instructions to exploit inter-instruction parallelism and improve pipeline utilization.
• Use high-speed on-chip memory (SMEM or registers) to cache data, reducing the number of memory-access instructions and the pressure on every level of the memory hierarchy.
• Reduce the number of other non-essential instructions.
• Tile the attention matrix so that memory access and computation are performed only on the meaningful tiles of the mask matrix; this saves unnecessary compute and bandwidth.

Overview: HSTU Attention kernel
• Much better performance than OAI-triton
• Open source in NVIDIA/RecSys-example and pytorch/FBGEMM
• Supports various customized masks applied in Generative Recommendation (e.g. contextual + group target mask)
• Supports experimental FP8 attention for training
• Supports Paged Attention for inference

Optimization Overview
• Block Tiling Shape
• Warp Tiling Shape

Fused Attention Forward
1. The CTA tile is divided along Seq_Q; if it were divided along Seq_K, accumulating O would require a global atomicAdd.

Pseudocode for the forward pass (a plain CPU reference sketch is given at the end of this subsection):

    for i in Q_tiles:
        acc_o = {0}
        for j in K/V_tiles:
            acc_s  = GEMM(Q_tile[i], K_tile[j])
            acc_s += Bias_tile[i, j]
            acc_s *= alpha
            P = silu(acc_s)
            P = mask(P)
            acc_o += GEMM(P, V_tile[j])
        acc_o *= scale

Tiling Size in Attention
Usage of registers and shared memory under different tile sizes and parameters:
1. The larger DIM is, the more registers are needed, since the sizes of tOrV and acc_o depend on DIM.
2. The larger kBlockM is, the more registers are needed, since the sizes of tSrQ and acc_s/acc_o depend on kBlockM.
3. The larger kBlockN is, the more registers are needed, since the sizes of tSrK and acc_s depend on kBlockN.
4. The larger kBlockM/kBlockN and DIM are, the more shared memory is needed, since the shared-memory footprint follows the tile sizes.

Besides, after G2S loads the Q tile, we can either S2R it so that the Q tile stays resident in registers, or leave it in shared memory.
• S2R: increases register pressure, reduces shared-memory pressure, and reduces the number of LDSM instructions.
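To make the math of the tiled loop above concrete, here is a minimal single-head CPU reference, a sketch only: the fused kernel performs the same computation tile by tile with acc_s/acc_o in registers, while this version is meant purely for checking numerics. The function name hstu_attention_ref, the row-major layouts, the 0/1 mask convention, and where alpha and scale are applied mirror the pseudocode but are otherwise assumptions, not the released kernel API.

    // Minimal CPU reference for the fused forward pass sketched above.
    // Assumptions (not the released kernel API): row-major Q/K/V, a 0/1 mask,
    // alpha applied to (QK^T + rab) before silu, and a final epilogue scale.
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // silu(x) = x * sigmoid(x): the pointwise activation phi_2 used by HSTU.
    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // Q: [seq_q, dim], K/V: [seq_k, dim], rab/mask: [seq_q, seq_k], O: [seq_q, dim]
    void hstu_attention_ref(const std::vector<float>& Q, const std::vector<float>& K,
                            const std::vector<float>& V, const std::vector<float>& rab,
                            const std::vector<float>& mask, std::vector<float>& O,
                            std::size_t seq_q, std::size_t seq_k, std::size_t dim,
                            float alpha, float scale) {
      O.assign(seq_q * dim, 0.0f);
      for (std::size_t i = 0; i < seq_q; ++i) {
        for (std::size_t j = 0; j < seq_k; ++j) {
          // acc_s = Q[i] . K[j] + rab[i, j]
          float s = 0.0f;
          for (std::size_t d = 0; d < dim; ++d) s += Q[i * dim + d] * K[j * dim + d];
          // P = mask(silu(alpha * acc_s))
          float p = silu(alpha * (s + rab[i * seq_k + j])) * mask[i * seq_k + j];
          // acc_o += P * V[j]
          for (std::size_t d = 0; d < dim; ++d) O[i * dim + d] += p * V[j * dim + d];
        }
        // epilogue: acc_o *= scale
        for (std::size_t d = 0; d < dim; ++d) O[i * dim + d] *= scale;
      }
    }

Because there is no softmax, each (i, j) contribution is independent, which is exactly what allows the kernel to skip masked-out tiles entirely rather than merely zeroing them.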
Tiling Size in Attention
Instruction count under different tile sizes and parameters. Each CTA needs to perform G2S and S2R on the complete K and V, which means K and V incur redundant G2S and S2R traffic across CTAs.
• The smaller BlockM is, the more blocks there are, leading to more redundant G2S and S2R operations for K and V and more LDSM/LDGSTS instructions.
• If the Q tile stays in shared memory, we need to repeatedly S2R the Q tile when doing the QK GEMM. The smaller BlockN is, the more tiles there are along seq_K, which leads to more repeated S2R operations and more LDSM instructions.
Thus, the tile shape has to balance instruction count against register and shared-memory pressure.

TileMMA in Attention FWD: the influence of tile size on TiledMMA

[Diagram: with 4 warps tiled along M, W0–W3 each take a slice of Q_tile, all warps read the full K_tile, and each produces the corresponding slice of S_tile. Example configuration: BlockM = 64, BlockN = 64, DIM = 64.]

    using TiledMma = TiledMMA<
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>,
        Layout<Shape<_4, _1, _1>>,
        Tile<Int<4 * 16>, _16, _16>>;

[Diagram: along the K-dimension loop, LDSM and MMA interleave: LDSM(0), MMA(0), LDSM(1), MMA(1), LDSM(2), MMA(2), LDSM(3), MMA(3).]

Set an appropriate TileShape for the TiledMMA:
• Keep the minimum granularity needed to satisfy the MMA instruction (the _16 entries in Tile<>).
• Use LDSM x4 whenever possible.

TileMMA in Attention FWD: the influence of different warp arrangements on TiledMMA

[Diagram: two candidate warp arrangements (left/right) of W0–W3 over Q_tile, K_tile, and P_tile.]

    using TiledMmaQK = TiledMMA<
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>,
        Layout<Shape<_1, Int<4>, _1>>,
        Tile<_16, Int<4 * 16>, _16>>;

    using TiledMmaPV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<_1, _1, Int<4>>>,
        Tile<_16, _16, Int<4 * 16>>>;

For example, comparing the left and right options:
• When Seq_Q is small, it is usually recommended to use the right option, as it launches more blocks and thus utilizes the GPU more fully.
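The warp-arrangement trade-off above is easy to inspect offline before committing to a kernel. The following is a minimal host-side sketch (assuming the CUTLASS/CuTe headers are on the include path; the file name and build command are illustrative): it rebuilds two TiledMMA variants with make_tiled_mma and prints their thread counts and layouts, and print_latex can additionally emit a picture of the per-thread value ownership. The PV arrangement (TiledMmaPV above) can be inspected the same way.

    // inspect_tiled_mma.cu -- host-only sketch, e.g.:
    //   nvcc -std=c++17 -I/path/to/cutlass/include inspect_tiled_mma.cu -o inspect
    #include <cstdio>
    #include <cute/tensor.hpp>
    #include <cute/atom/mma_atom.hpp>
    #include <cute/arch/mma_sm80.hpp>

    using namespace cute;

    int main() {
      // Option A: 4 warps tiled along M (the BlockM = 64 example above).
      auto mma_m = make_tiled_mma(SM80_16x8x16_F32BF16BF16F32_TN{},
                                  Layout<Shape<_4, _1, _1>>{},     // warp layout (M, N, K)
                                  Tile<Int<4 * 16>, _16, _16>{});  // permutation / tile sizes

      // Option B: 4 warps tiled along N for the QK^T GEMM.
      auto mma_n = make_tiled_mma(SM80_16x8x16_F32BF16BF16F32_TN{},
                                  Layout<Shape<_1, _4, _1>>{},
                                  Tile<_16, Int<4 * 16>, _16>{});

      std::printf("Option A (4 warps along M), threads per tile: %d\n", int(size(mma_m)));
      print(mma_m);
      std::printf("\nOption B (4 warps along N), threads per tile: %d\n", int(size(mma_n)));
      print(mma_n);

      // print_latex(mma_m);  // emits a TikZ picture of the thread/value layout
      return 0;
    }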