简介
myFA分支完成了任务要求。通过分块 + 融合实现了简化(忽略了sqrt(dk)和safe softma)的fa1,并通过ISPC内部函数实现向量化。
在ISPC向量化中我使用了两个方法:
- 最内层循环通过foreach实现并行
- 最后一步的矩阵乘法分块并行。
Br, Bc
已知默认值d = 32
RTX 3080的L1缓存大小为 M = 128 KB(每SM) = 32K floats
Br = min(ceil(M / 4d), d) = min(256, 32) = 32
Bc = ceil(M / 4d) = 256
TILE_SIZE
TILE_SIZE的意义更多在于并行而非分块。
我的matmulRow函数参考了https://github.com/BienBoy/cs149gpt/blob/main/module.ispc。
matmulRow函数处理一行与一个矩阵的乘积。naive的方法是一行逐次与多列点乘。此处的并行优化是,TILE_SIZE个列为一组,每次处理一行与多个列的点乘。TILE_SIZE的大小只是防止占用空间过大。
然而此方法只适用于A * B而不适用于A * B_t,实践证明matmul_TRow是一个负优化,可能是因为TILE_SIZE个行反而降低了局部性。
//A 是原矩阵的一行,长度为N,B 是 NxK的矩阵,C 是 1行K列的输出
export inline void matmulRow(
uniform float A[],
uniform float B[],
uniform float C[],
uniform int N,
uniform int K
){
uniform float sumTile[TILE_SIZE];
for (uniform int k = 0; k < K; k += TILE_SIZE) {
foreach (ki = 0 ... TILE_SIZE) {
sumTile[ki] = 0.0f;
}
for (uniform int n = 0; n < N; n++) {
foreach(ki = k ... min(k + TILE_SIZE, K)) {
sumTile[ki - k] += A[n] * B[n * K + ki];
}
}
foreach (ki = k ... min(k + TILE_SIZE, K)) {
C[ki] = sumTile[ki - k];
}
}
}
// part1中可尝试替换的部分,实践证明可行
// matmulRow(QK_t + i * N, V, O + i * d, N, d);
for (uniform int j = 0; j < d; j++) {
float sum = 0;
foreach(k = 0 ... N) {
float val1 = QK_t[i * N + k];
float val2 = V[b * H * N * d + h * N * d + k * d + j];
sum += val1 * val2;
}
uniform float sum1 = reduce_add(sum);
O[i * d + j] = sum1;
}
// A 是原矩阵的一行,长度为K,B 是NxK的矩阵,C 是 1行N列的输出
export inline void matmul_TRow(
uniform float A[],
uniform float B[],
uniform float C[],
uniform int N,
uniform int K
){
uniform float sumTile[TILE_SIZE];
for (uniform int n = 0; n < N; n += TILE_SIZE) {
foreach (ni = 0 ... TILE_SIZE) {
sumTile[ni] = 0.0f;
}
for (uniform int k = 0; k < K; k++) {
foreach(ni = n ... min(n + TILE_SIZE, N)) {
sumTile[ni - n] += A[k] * B[ni * K + k];
}
}
foreach (ni = n ... min(n + TILE_SIZE, N)) {
C[ni] = sumTile[ni - n];
}
}
}
// part1中可尝试替换的部分,实践证明是负优化
// QK^t
// matmul_TRow(Q + i * d, K, QK_t + i * N, N, d);
for (uniform int j = 0; j < N; j++) {
float val = 0;
foreach (k = 0 ... d){
float val1 = Q[i * d + k];
float val2 = K[j * d + k];
val += val1 * val2;
}
uniform float sum = reduce_add(val);
if (programIndex == 0) {
QK_t[i * N + j] = sum;
}
}
此尝试实际上证明了,让多个列并行效果会更好,RTX 3080 的 Cache Line大小通常为 64 字节,即16个float大小,理论上讲TILE_SIZE = 16刚好。
但是实践过程中发现32效果最好,而d又恰好为32,因此此处无需分块,三层循环的矩阵乘法代码可简化为:对于A * B,可对第二层foreach并行,对于A * B_t,对最内层直接并行。即使当d!=32需要分块时,也无需单独写函数,只需要foreach外面加上一层分块的循环即可。