Spaces:
Running
Running
Commit · 2eca371
1 Parent(s): 5e088a7
CUDA: fix crash on large batch size for MoE models (llama/13384)
Browse files
- ggml/src/ggml-cuda/getrows.cu +14 -12
ggml/src/ggml-cuda/getrows.cu
CHANGED
|
@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
|
|
| 10 |
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
| 11 |
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
| 12 |
|
| 13 |
-
|
| 14 |
-
const int
|
| 15 |
-
const int
|
| 16 |
-
const int
|
|
|
|
| 17 |
|
| 18 |
if (i00 >= ne00) {
|
| 19 |
return;
|
|
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
|
|
| 46 |
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
| 47 |
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
| 48 |
|
| 49 |
-
|
| 50 |
-
const int
|
| 51 |
-
const int
|
| 52 |
-
const int
|
|
|
|
| 53 |
|
| 54 |
if (i00 >= ne00) {
|
| 55 |
return;
|
|
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
|
|
| 94 |
const size_t nb1, const size_t nb2, const size_t nb3,
|
| 95 |
cudaStream_t stream) {
|
| 96 |
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 97 |
-
const int
|
| 98 |
-
const dim3 block_nums(
|
| 99 |
|
| 100 |
// strides in elements
|
| 101 |
// const size_t s0 = nb0 / sizeof(dst_t);
|
|
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
|
|
| 127 |
const size_t nb1, const size_t nb2, const size_t nb3,
|
| 128 |
cudaStream_t stream) {
|
| 129 |
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 130 |
-
const int
|
| 131 |
-
const dim3 block_nums(
|
| 132 |
|
| 133 |
// strides in elements
|
| 134 |
// const size_t s0 = nb0 / sizeof(dst_t);
|
|
|
|
| 10 |
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
| 11 |
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
| 12 |
|
| 13 |
+
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
| 14 |
+
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
|
| 15 |
+
const int i10 = blockIdx.x;
|
| 16 |
+
const int i11 = blockIdx.z / ne12;
|
| 17 |
+
const int i12 = blockIdx.z % ne12;
|
| 18 |
|
| 19 |
if (i00 >= ne00) {
|
| 20 |
return;
|
|
|
|
| 47 |
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
| 48 |
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
| 49 |
|
| 50 |
+
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
| 51 |
+
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
|
| 52 |
+
const int i10 = blockIdx.x;
|
| 53 |
+
const int i11 = blockIdx.z / ne12;
|
| 54 |
+
const int i12 = blockIdx.z % ne12;
|
| 55 |
|
| 56 |
if (i00 >= ne00) {
|
| 57 |
return;
|
|
|
|
| 96 |
const size_t nb1, const size_t nb2, const size_t nb3,
|
| 97 |
cudaStream_t stream) {
|
| 98 |
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 99 |
+
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
| 100 |
+
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
| 101 |
|
| 102 |
// strides in elements
|
| 103 |
// const size_t s0 = nb0 / sizeof(dst_t);
|
|
|
|
| 129 |
const size_t nb1, const size_t nb2, const size_t nb3,
|
| 130 |
cudaStream_t stream) {
|
| 131 |
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 132 |
+
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
| 133 |
+
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
| 134 |
|
| 135 |
// strides in elements
|
| 136 |
// const size_t s0 = nb0 / sizeof(dst_t);
|