Flash Attention - 2

CS149GPT实现一个简化的FAv1

简介

cs149

我的仓库

myFA分支完成了任务要求。通过分块 + 融合实现了简化(忽略了sqrt(dk)和safe softma)的fa1,并通过ISPC内部函数实现向量化。

在ISPC向量化中我使用了两个方法:

  1. 最内层循环通过foreach实现并行
  2. 最后一步的矩阵乘法分块并行。

Br, Bc

已知默认值d = 32
RTX 3080的L1缓存大小为 M = 128 KB(每SM) = 32K floats
Br = min(ceil(M / 4d), d) = min(256, 32) = 32
Bc = ceil(M / 4d) = 256

TILE_SIZE

TILE_SIZE的意义更多在于并行而非分块。
我的matmulRow函数参考了https://github.com/BienBoy/cs149gpt/blob/main/module.ispc。
matmulRow函数处理一行与一个矩阵的乘积。naive的方法是一行逐次与多列点乘。此处的并行优化是,TILE_SIZE个列为一组,每次处理一行与多个列的点乘。TILE_SIZE的大小只是防止占用空间过大。
然而此方法只适用于A * B而不适用于A * B_t,实践证明matmul_TRow是一个负优化,可能是因为TILE_SIZE个行反而降低了局部性。

//A 是原矩阵的一行,长度为N,B 是 NxK的矩阵,C 是 1行K列的输出
export inline void matmulRow(
    uniform float A[], 
    uniform float B[], 
    uniform float C[], 
    uniform int N, 
    uniform int K
){
    uniform float sumTile[TILE_SIZE];
    for (uniform int k = 0; k < K; k += TILE_SIZE) {
        foreach (ki = 0 ... TILE_SIZE) {
            sumTile[ki] = 0.0f;
        }
        for (uniform int n = 0; n < N; n++) {
            foreach(ki = k ... min(k + TILE_SIZE, K)) {
                sumTile[ki - k] += A[n] * B[n * K + ki];
            }
        } 
        foreach (ki = k ... min(k + TILE_SIZE, K)) {
            C[ki] = sumTile[ki - k];
        }
    }
}
// part1中可尝试替换的部分,实践证明可行
// matmulRow(QK_t + i * N, V, O + i * d, N, d);
for (uniform int j = 0; j < d; j++) {
    float sum = 0;
    foreach(k = 0 ... N) {
        float val1 = QK_t[i * N + k];
        float val2 = V[b * H * N * d + h * N * d + k * d + j];
        sum += val1 * val2;
    }
    uniform float sum1 = reduce_add(sum);
    O[i * d + j] = sum1;
}

// A 是原矩阵的一行,长度为K,B 是NxK的矩阵,C 是 1行N列的输出
export inline void matmul_TRow(
    uniform float A[], 
    uniform float B[], 
    uniform float C[], 
    uniform int N, 
    uniform int K
){
    uniform float sumTile[TILE_SIZE];

    for (uniform int n = 0; n < N; n += TILE_SIZE) {
        foreach (ni = 0 ... TILE_SIZE) {
            sumTile[ni] = 0.0f;
        }
        for (uniform int k = 0; k < K; k++) {
            foreach(ni = n ... min(n + TILE_SIZE, N)) {
                sumTile[ni - n] += A[k] * B[ni * K + k];
            }
        }
        foreach (ni = n ... min(n + TILE_SIZE, N)) {
            C[ni] = sumTile[ni - n];
        }
    }
}
// part1中可尝试替换的部分,实践证明是负优化
// QK^t
// matmul_TRow(Q + i * d, K, QK_t + i * N, N, d);
for (uniform int j = 0; j < N; j++) {
    float val = 0;
    foreach (k = 0 ... d){
        float val1 = Q[i * d + k];
        float val2 = K[j * d + k];
        val += val1 * val2;
    }
    uniform float sum = reduce_add(val);
    if (programIndex == 0) {
        QK_t[i * N + j] = sum;
    }
}

此尝试实际上证明了,让多个列并行效果会更好,RTX 3080 的 Cache Line大小通常为 64 字节,即16个float大小,理论上讲TILE_SIZE = 16刚好。
但是实践过程中发现32效果最好,而d又恰好为32,因此此处无需分块,三层循环的矩阵乘法代码可简化为:对于A * B,可对第二层foreach并行,对于A * B_t,对最内层直接并行。即使当d!=32需要分块时,也无需单独写函数,只需要foreach外面加上一层分块的循环即可。