Agenda
• HSTU block Introduction
• Attention optimization
• Rab & dRab
• CuTe Practical
• QA
HSTU block

$$U(X),\,V(X),\,Q(X),\,K(X) = \mathrm{Split}\!\left(\phi_1(f_1(X))\right)$$
$$A(X)V(X) = \phi_2\!\left(Q(X)K(X)^T + \mathrm{rab}^{p,t}\right)V(X)$$
$$Y(X) = f_2\!\left(\mathrm{Norm}\!\left(A(X)V(X)\right) \odot U(X)\right)$$

In the forward pass, computations are fused into a few kernels, with the result of $QK^T$ kept in registers as an intermediate result.

In the backward pass, computations also need to be fused into a few kernels. Since the computation of the gradients depends on the result of $QK^T$, this result is recalculated and kept in registers as an intermediate result.
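As a plain reference for the three formulas above (not the fused kernel), they can be sketched in NumPy. Treating f1/f2 as single linear projections and both φ1 and φ2 as SiLU is an assumption of this sketch, as are all variable names:

```python
import numpy as np

def silu(x):
    # SiLU activation, assumed here for both phi1 and phi2
    return x / (1.0 + np.exp(-x))

def hstu_block_forward(X, W1, W2, rab, eps=1e-6):
    """Sketch of one HSTU block; shapes are illustrative, not the kernel layout.

    X:   (n, d)   input sequence
    W1:  (d, 4d)  weights of the input projection f1 (assumed linear)
    W2:  (d, d)   weights of the output projection f2 (assumed linear)
    rab: (n, n)   relative attention bias rab^{p,t}
    """
    # U(X), V(X), Q(X), K(X) = Split(phi1(f1(X)))
    U, V, Q, K = np.split(silu(X @ W1), 4, axis=1)
    # A(X)V(X) = phi2(Q K^T + rab) V
    AV = silu(Q @ K.T + rab) @ V
    # Y(X) = f2(Norm(A(X)V(X)) ⊙ U(X)); Norm taken as non-affine layer norm
    mu = AV.mean(axis=1, keepdims=True)
    sigma = AV.std(axis=1, keepdims=True)
    normed = (AV - mu) / (sigma + eps)
    return (normed * U) @ W2
```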
For the LayerNorm backward there are four equivalent computational formulas. Let $\mu$ and $\sigma$ be the per-row mean and standard deviation, $\psi = 1/\sigma$, and $\gamma$, $\beta$ the affine scale and shift, so that $y_i = \gamma_i\,\psi(x_i-\mu)+\beta_i$:

$$dx_i = \psi\left(dy_i\gamma_i - \frac{1}{N}\sum_k dy_k\gamma_k - \psi^2(x_i-\mu)\,\frac{1}{N}\sum_k dy_k\gamma_k(x_k-\mu)\right)$$
$$dx_i = \psi\left(dy_i\gamma_i - \frac{1}{N}\sum_k dy_k\gamma_k - \psi\,\frac{y_i-\beta_i}{\gamma_i}\,\frac{1}{N}\sum_k dy_k\gamma_k(x_k-\mu)\right)$$
$$dx_i = \psi\left(dy_i\gamma_i - \frac{1}{N}\sum_k dy_k\gamma_k - \psi(x_i-\mu)\,\frac{1}{N}\sum_k dy_k\gamma_k\,\frac{y_k-\beta_k}{\gamma_k}\right)$$
$$dx_i = \psi\left(dy_i\gamma_i - \frac{1}{N}\sum_k dy_k\gamma_k - \frac{y_i-\beta_i}{\gamma_i}\,\frac{1}{N}\sum_k dy_k\gamma_k\,\frac{y_k-\beta_k}{\gamma_k}\right)$$
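The equivalence of the four forms follows from $(y_i-\beta_i)/\gamma_i = \psi(x_i-\mu)$, and can be checked numerically. The sketch below (eps omitted, all names illustrative) derives the normalized value either from x or from y, giving the four combinations:

```python
import numpy as np

def layernorm_bwd_forms(x, gamma, beta, dy):
    """Compute dx via the four equivalent formulas (per-row LayerNorm)."""
    mu = x.mean(-1, keepdims=True)
    psi = 1.0 / x.std(-1, keepdims=True)      # psi = 1/sigma
    y = gamma * psi * (x - mu) + beta
    xhat_from_x = psi * (x - mu)              # normalized value from x
    xhat_from_y = (y - beta) / gamma          # same value, derived from y
    g = dy * gamma
    c1 = g.mean(-1, keepdims=True)            # (1/N) sum_k dy_k gamma_k
    forms = []
    for xh_i in (xhat_from_x, xhat_from_y):
        for xh_k in (xhat_from_x, xhat_from_y):
            c2 = (g * xh_k).mean(-1, keepdims=True)
            forms.append(psi * (g - c1 - xh_i * c2))
    return forms
```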
• Adjust the computation flow (order of instructions) and synchronization timing so that computation and memory access hide each other's latency.
• Unroll computation or memory access instructions to exploit inter-instruction parallelism, enhancing pipeline utilization.
• Utilize high-speed on-chip memory (such as SMEM or REG) to cache data, thereby reducing the number of memory access instructions and alleviating pressure on the various levels of the memory hierarchy.
• Reduce the number of other non-essential instructions.
By employing a tiling method for the attention matrix, memory access and computation are performed only on the meaningful "tiles" of the mask matrix. This avoids wasting compute and bandwidth on fully masked regions.
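The tile-skipping idea can be sketched on the host side as below; the real kernel derives tile validity from mask metadata rather than scanning a dense mask, so this is only an illustration:

```python
import numpy as np

def nonempty_tiles(mask, tile_m, tile_n):
    """Return (i, j) tile coordinates containing at least one unmasked
    entry; fully masked tiles are skipped entirely, saving both the
    GEMM work and the corresponding K/V memory traffic."""
    M, N = mask.shape
    tiles = []
    for i in range(0, M, tile_m):
        for j in range(0, N, tile_n):
            if mask[i:i + tile_m, j:j + tile_n].any():
                tiles.append((i // tile_m, j // tile_n))
    return tiles
```

For a causal mask, for example, all tiles strictly above the diagonal are skipped.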
Overview HSTU Attention kernel
• Much better performance than OAI-triton
• Open source in NVIDIA/RecSys-example and pytorch/FBGEMM
• Support various customized masks applied in Generative Recommendation (contextual + group target mask)
• Support experimental FP8 attention for training
• Support Paged Attention for inference
Optimization Overview
[Figure: Block Tiling Shape / Warp Tiling Shape]
Fused Attention Forward
1. The CTA tile is divided along Seq_Q. If it were divided along Seq_K instead, calculating O would introduce a global atomicAdd.

    For i in Q_tiles:
        acc_o = {0}
        For j in K/V_tiles:
            acc_s  = GEMM(Q_tile[i], K_tile[j])
            acc_s += Bias_tile[i, j]
            acc_s *= alpha
            P = silu(acc_s)
            P = mask(P)
            acc_o += GEMM(P, V_tile[j])
        acc_o *= scale
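The same loop as a runnable NumPy reference: one accumulator per Q tile and no atomics, matching the Seq_Q partitioning above. All names and the dense bias/mask layout are illustrative, not the kernel's:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def tiled_attention_fwd(Q, K, V, bias, mask, alpha, scale, bm, bn):
    """NumPy mirror of the pseudocode: outer loop tiles Seq_Q (one
    accumulator per Q tile), inner loop streams K/V tiles."""
    n, d = Q.shape
    O = np.zeros((n, d))
    for i in range(0, n, bm):                       # For i in Q_tiles
        acc_o = np.zeros((min(bm, n - i), d))       # acc_o = {0}
        for j in range(0, n, bn):                   # For j in K/V tiles
            acc_s = Q[i:i + bm] @ K[j:j + bn].T     # GEMM(Q_tile, K_tile)
            acc_s += bias[i:i + bm, j:j + bn]
            acc_s *= alpha
            P = silu(acc_s)
            P = P * mask[i:i + bm, j:j + bn]        # mask(P)
            acc_o += P @ V[j:j + bn]                # GEMM(P, V_tile)
        O[i:i + bm] = acc_o * scale                 # acc_o *= scale
    return O
```

The tiled result matches the untiled computation, since the inner loop just splits the sum over Seq_K.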
Tiling Size in Attention
Usage of registers and shared memory under different tile sizes and params:
• Registers:
1. The larger Dim, the more registers, as the sizes of tOrV and acc_o are related to Dim.
2. The larger kBlockM, the more registers, as the sizes of tSrQ and acc_s/acc_o are related to kBlockM.
3. The larger kBlockN, the more registers, as the sizes of tSrK and acc_s are related to kBlockN.
• Shared memory:
The larger kBlockM/N and Dim, the more shared memory, as the tile sizes determine the shared-memory footprint.
Besides, after G2S loads the Q tile, we could S2R to keep the Q tile resident in registers, or leave it in shared memory.
• S2R: increases register pressure, reduces shared memory pressure, and reduces the number of LDSM instructions.
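The shared-memory side of this relationship can be made concrete with a rough byte count. This is an illustrative model (fp16 elements, Q/K/V tiles only), not the kernel's actual allocator:

```python
def smem_bytes(kBlockM, kBlockN, dim, bytes_per_elem=2, q_in_smem=True):
    """Rough shared-memory footprint of one CTA's tiles.

    The Q tile is optional, reflecting the S2R-vs-shared-memory choice:
    keeping Q in registers (q_in_smem=False) frees its SMEM share at
    the cost of register pressure.
    """
    q = kBlockM * dim if q_in_smem else 0   # Q tile: kBlockM x dim
    k = kBlockN * dim                       # K tile: kBlockN x dim
    v = kBlockN * dim                       # V tile: kBlockN x dim
    return (q + k + v) * bytes_per_elem
```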
Tiling Size in Attention
Instruction number under different tile sizes and params:
Each CTA needs to perform G2S and S2R on the complete K and V, meaning that K and V incur redundant G2S and S2R across CTAs.
• The smaller kBlockM, the more blocks, leading to more redundant G2S and S2R operations for K and V, and more LDSM/LDGSTS instructions.
• If Q_tile stays in shared memory, we need to repeatedly S2R Q_tile when doing the QK GEMM. This means the smaller kBlockN, the greater the number of tiles in the seq_K direction, which leads to more repetitions of the S2R operation and more LDSM instructions.
Thus: tile size is a trade-off. Larger tiles reduce redundant memory instructions, while smaller tiles reduce register and shared-memory pressure.
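The redundant K/V traffic described above can be modeled with a quick count. This is an illustrative model of element-loads (each CTA owns one Q tile and streams all of K), not the kernel's real instruction count:

```python
def kv_load_count(seq_q, seq_k, kBlockM):
    """Total K (or V) element-loads across all CTAs: each CTA handles
    one Q tile of height kBlockM but must stream the full seq_k, so
    shrinking kBlockM multiplies the redundant G2S/S2R traffic."""
    num_ctas = -(-seq_q // kBlockM)   # ceil(seq_q / kBlockM)
    return num_ctas * seq_k
```

Halving kBlockM doubles the redundant K/V loads, matching the first bullet above.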
TiledMMA in Attention FWD
The influence of tile size on TiledMMA.
[Figure: warps W0–W3 tiled over Q_tile, K_tile, and the resulting S_tile]

    using TiledMma = TiledMMA<...>