Spaces:
Running
Running
Commit
·
20644bf
1
Parent(s):
e32d905
CUDA: fix race conditions FlashAttention kernels (llama/13438)
Browse files
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED
|
@@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 874 |
}
|
| 875 |
}
|
| 876 |
|
|
|
|
|
|
|
| 877 |
// Write back combined meta data:
|
| 878 |
#pragma unroll
|
| 879 |
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
|
|
|
| 874 |
}
|
| 875 |
}
|
| 876 |
|
| 877 |
+
__syncthreads();
|
| 878 |
+
|
| 879 |
// Write back combined meta data:
|
| 880 |
#pragma unroll
|
| 881 |
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
|
@@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 168 |
for (int j = 0; j < ncols; ++j) {
|
| 169 |
KQ[j*D + tid] = -HALF_MAX_HALF;
|
| 170 |
}
|
|
|
|
| 171 |
|
| 172 |
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
| 173 |
|
|
|
|
| 168 |
for (int j = 0; j < ncols; ++j) {
|
| 169 |
KQ[j*D + tid] = -HALF_MAX_HALF;
|
| 170 |
}
|
| 171 |
+
__syncthreads();
|
| 172 |
|
| 173 |
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
| 174 |
|