Commit 05351ac by am17an
Parent: 47e02a8

CUDA: add softmax broadcast (llama/14475)


* CUDA: add softmax broadcast (see the sketch after this list)

* Pass by const ref

* Review: Use blockDims for indexing, remove designated initializers

* Add TODO for noncontiguous input/output

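The broadcast itself is plain index arithmetic on the mask pointer. Below is a minimal sketch of that arithmetic, using the hypothetical helper name `broadcast_mask_row` for illustration; the commit inlines the same computation directly in `soft_max_f32` (see the softmax.cu diff further down). A mask stored once (ne12 == ne13 == 1) is thus reused for every head and batch entry without being materialized per row.

```cuda
// Illustrative sketch (not part of the commit): how the new kernel picks the mask row
// for a given softmax row. i01/i02/i03 come from blockIdx.x/y/z; a mask whose ne12 or
// ne13 is 1 is repeated along that dimension by the modulo, which is the broadcast.
// The nb1x strides are in bytes, as in ggml, hence the byte-level pointer arithmetic.
// (The kernel additionally guards against mask == nullptr.)
template <typename T>
static __device__ const T * broadcast_mask_row(
        const T * mask, int64_t i01, int64_t i02, int64_t i03,
        int64_t ne12, int64_t ne13, int64_t nb11, int64_t nb12, int64_t nb13) {
    const int64_t i11 = i01;        // mask rows follow the softmax rows
    const int64_t i12 = i02 % ne12; // broadcast over dim 2 (heads) when ne12 == 1
    const int64_t i13 = i03 % ne13; // broadcast over dim 3 (batch) when ne13 == 1
    return (const T *)((const char *) mask + i11*nb11 + i12*nb12 + i13*nb13);
}
```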
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
```diff
@@ -3329,13 +3329,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            return true;
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
```
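With batching and mask broadcast handled by the new kernel, the backend can report GGML_OP_SOFT_MAX as unconditionally supported. A hedged sketch of a graph the CUDA backend previously rejected (src0 batched in dim 3, mask broadcast over dims 2 and 3) and now accepts, assuming the `ggml_soft_max_ext()` and `ggml_new_tensor_4d()` declarations from ggml.h; all names and shapes are illustrative:

```cuda
#include "ggml.h"

// Hypothetical helper, not part of the commit.
static struct ggml_tensor * build_broadcast_softmax(
        struct ggml_context * ctx,
        int64_t n_kv, int64_t n_tokens, int64_t n_head, int64_t n_batch,
        float scale, float max_bias) {
    // logits: [n_kv, n_tokens, n_head, n_batch] -- ne[3] != 1 was previously unsupported
    struct ggml_tensor * logits = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens, n_head, n_batch);
    // mask: [n_kv, n_tokens, 1, 1] -- broadcast over heads and batch by the new kernel
    struct ggml_tensor * mask   = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, n_tokens, 1, 1);
    return ggml_soft_max_ext(ctx, logits, mask, scale, max_bias);
}
```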
ggml/src/ggml-cuda/softmax.cu CHANGED
```diff
@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -21,16 +44,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    //TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@@ -38,7 +69,7 @@ static __global__ void soft_max_f32(
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
             break;
         }
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale = 1.0f;
     float max_bias = 0.0f;
 
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
     }
 }
 
```
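For contiguous tensors, the new 3-D launch grid block_nums(ne01, ne02, ne03) visits exactly the rows the old one-block-per-row launch did, and the kernel's flattened rowx recovers the old numbering. A standalone host-side check of that equivalence (illustrative only, not part of the commit):

```cuda
// rowx = blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y
//      = i01 + i02*ne01 + i03*ne01*ne02
// matches a flat 0..nrows_x-1 counter when rows are enumerated in memory order.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t ne01 = 7, ne02 = 3, ne03 = 2; // arbitrary example shape
    int64_t flat = 0;                           // old rowx: one block per row
    for (int64_t i03 = 0; i03 < ne03; ++i03) {          // blockIdx.z
        for (int64_t i02 = 0; i02 < ne02; ++i02) {      // blockIdx.y
            for (int64_t i01 = 0; i01 < ne01; ++i01) {  // blockIdx.x
                const int64_t rowx = i01 + i02*ne01 + i03*ne01*ne02;
                assert(rowx == flat++);
            }
        }
    }
    return 0;
}
```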