lhez commited on
Commit
4434043
·
1 Parent(s): 90cefa0

opencl : broadcast for soft_max (llama/14510)

Browse files
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5763
 
5764
  cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
5765
 
5766
- const int ne00 = src0 ? src0->ne[0] : 0;
5767
- const int ne01 = src0 ? src0->ne[1] : 0;
5768
- const int ne02 = src0 ? src0->ne[2] : 0;
5769
- const int ne03 = src0 ? src0->ne[3] : 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5770
 
5771
  float scale, max_bias;
5772
  memcpy(&scale, dst->op_params + 0, sizeof(float));
5773
  memcpy(&max_bias, dst->op_params + 1, sizeof(float));
5774
 
5775
- const int nrows_x = ggml_nrows(src0);
5776
- const int nrows_y = src0->ne[1];
5777
-
5778
- const int n_head = nrows_x/nrows_y;
5779
  const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
5780
 
5781
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5820
  CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5821
  CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5822
  CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5823
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5824
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5825
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
5826
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
5827
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
5828
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
5829
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
 
 
 
 
 
 
 
 
 
5830
 
5831
  size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
5832
  size_t local_work_size[] = {(size_t)nth, 1, 1};
 
5763
 
5764
  cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
5765
 
5766
+ const int ne00 = src0->ne[0];
5767
+ const int ne01 = src0->ne[1];
5768
+ const int ne02 = src0->ne[2];
5769
+ const int ne03 = src0->ne[3];
5770
+
5771
+ const cl_long nb01 = src0->nb[1];
5772
+ const cl_long nb02 = src0->nb[2];
5773
+ const cl_long nb03 = src0->nb[3];
5774
+
5775
+ const int ne12 = src1 ? src1->ne[2] : 0;
5776
+ const int ne13 = src1 ? src1->ne[3] : 0;
5777
+
5778
+ const cl_long nb11 = src1 ? src1->nb[1] : 0;
5779
+ const cl_long nb12 = src1 ? src1->nb[2] : 0;
5780
+ const cl_long nb13 = src1 ? src1->nb[3] : 0;
5781
+
5782
+ const cl_long nb1 = dst->nb[1];
5783
+ const cl_long nb2 = dst->nb[2];
5784
+ const cl_long nb3 = dst->nb[3];
5785
 
5786
  float scale, max_bias;
5787
  memcpy(&scale, dst->op_params + 0, sizeof(float));
5788
  memcpy(&max_bias, dst->op_params + 1, sizeof(float));
5789
 
5790
+ const int n_head = src0->ne[2];
 
 
 
5791
  const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
5792
 
5793
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
 
5832
  CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5833
  CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5834
  CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5835
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
5836
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
5837
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
5838
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5839
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
5840
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
5841
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
5842
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
5843
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
5844
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
5845
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
5846
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
5847
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
5848
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
5849
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
5850
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
5851
 
5852
  size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
5853
  size_t local_work_size[] = {(size_t)nth, 1, 1};
ggml/src/ggml-opencl/kernels/softmax_4_f16.cl CHANGED
@@ -22,32 +22,45 @@
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max_4_f16(
25
- global float * src0,
26
  ulong offset0,
27
- global half * src1,
28
  ulong offset1,
29
- global float * dst,
30
  ulong offsetd,
31
  int ne00,
32
- int ne01,
33
- int ne02,
 
 
 
 
 
 
 
 
 
34
  float scale,
35
  float max_bias,
36
  float m0,
37
  float m1,
38
  int n_head_log2
39
  ) {
40
- src0 = (global float *)((global char *)src0 + offset0);
41
- src1 = (global half *)((global char *)src1 + offset1);
42
- dst = (global float *)((global char *)dst + offsetd);
43
 
44
  int i03 = get_group_id(2);
45
  int i02 = get_group_id(1);
46
  int i01 = get_group_id(0);
47
 
48
- global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
49
- global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
50
- global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
 
 
 
51
 
52
  float slope = 1.0f;
53
 
 
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max_4_f16(
25
+ global char * src0,
26
  ulong offset0,
27
+ global char * src1,
28
  ulong offset1,
29
+ global char * dst,
30
  ulong offsetd,
31
  int ne00,
32
+ ulong nb01,
33
+ ulong nb02,
34
+ ulong nb03,
35
+ int ne12,
36
+ int ne13,
37
+ ulong nb11,
38
+ ulong nb12,
39
+ ulong nb13,
40
+ ulong nb1,
41
+ ulong nb2,
42
+ ulong nb3,
43
  float scale,
44
  float max_bias,
45
  float m0,
46
  float m1,
47
  int n_head_log2
48
  ) {
49
+ src0 = src0 + offset0;
50
+ src1 = src1 + offset1;
51
+ dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
54
  int i02 = get_group_id(1);
55
  int i01 = get_group_id(0);
56
 
57
+ int i13 = i03%ne13;
58
+ int i12 = i02%ne12;
59
+ int i11 = i01;
60
+
61
+ global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
+ global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63
+ global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
66
 
ggml/src/ggml-opencl/kernels/softmax_4_f32.cl CHANGED
@@ -22,32 +22,45 @@
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max_4(
25
- global float * src0,
26
  ulong offset0,
27
- global float * src1,
28
  ulong offset1,
29
- global float * dst,
30
  ulong offsetd,
31
  int ne00,
32
- int ne01,
33
- int ne02,
 
 
 
 
 
 
 
 
 
34
  float scale,
35
  float max_bias,
36
  float m0,
37
  float m1,
38
  int n_head_log2
39
  ) {
40
- src0 = (global float*)((global char*)src0 + offset0);
41
- src1 = (global float*)((global char*)src1 + offset1);
42
- dst = (global float*)((global char*)dst + offsetd);
43
 
44
  int i03 = get_group_id(2);
45
  int i02 = get_group_id(1);
46
  int i01 = get_group_id(0);
47
 
48
- global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
49
- global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
50
- global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
 
 
 
51
 
52
  float slope = 1.0f;
53
 
 
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max_4(
25
+ global char * src0,
26
  ulong offset0,
27
+ global char * src1,
28
  ulong offset1,
29
+ global char * dst,
30
  ulong offsetd,
31
  int ne00,
32
+ ulong nb01,
33
+ ulong nb02,
34
+ ulong nb03,
35
+ int ne12,
36
+ int ne13,
37
+ ulong nb11,
38
+ ulong nb12,
39
+ ulong nb13,
40
+ ulong nb1,
41
+ ulong nb2,
42
+ ulong nb3,
43
  float scale,
44
  float max_bias,
45
  float m0,
46
  float m1,
47
  int n_head_log2
48
  ) {
49
+ src0 = src0 + offset0;
50
+ src1 = src1 + offset1;
51
+ dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
54
  int i02 = get_group_id(1);
55
  int i01 = get_group_id(0);
56
 
57
+ int i13 = i03%ne13;
58
+ int i12 = i02%ne12;
59
+ int i11 = i01;
60
+
61
+ global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
+ global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63
+ global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
66
 
ggml/src/ggml-opencl/kernels/softmax_f16.cl CHANGED
@@ -22,32 +22,45 @@
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max_f16(
25
- global float * src0,
26
  ulong offset0,
27
- global half * src1,
28
  ulong offset1,
29
- global float * dst,
30
  ulong offsetd,
31
  int ne00,
32
- int ne01,
33
- int ne02,
 
 
 
 
 
 
 
 
 
34
  float scale,
35
  float max_bias,
36
  float m0,
37
  float m1,
38
  int n_head_log2
39
  ) {
40
- src0 = (global float *)((global char *)src0 + offset0);
41
- src1 = (global half *)((global char *)src1 + offset1);
42
- dst = (global float *)((global char *)dst + offsetd);
43
 
44
  int i03 = get_group_id(2);
45
  int i02 = get_group_id(1);
46
  int i01 = get_group_id(0);
47
 
48
- global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
49
- global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
50
- global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
 
 
 
51
 
52
  float slope = 1.0f;
53
 
 
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max_f16(
25
+ global char * src0,
26
  ulong offset0,
27
+ global char * src1,
28
  ulong offset1,
29
+ global char * dst,
30
  ulong offsetd,
31
  int ne00,
32
+ ulong nb01,
33
+ ulong nb02,
34
+ ulong nb03,
35
+ int ne12,
36
+ int ne13,
37
+ ulong nb11,
38
+ ulong nb12,
39
+ ulong nb13,
40
+ ulong nb1,
41
+ ulong nb2,
42
+ ulong nb3,
43
  float scale,
44
  float max_bias,
45
  float m0,
46
  float m1,
47
  int n_head_log2
48
  ) {
49
+ src0 = src0 + offset0;
50
+ src1 = src1 + offset1;
51
+ dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
54
  int i02 = get_group_id(1);
55
  int i01 = get_group_id(0);
56
 
57
+ int i13 = i03%ne13;
58
+ int i12 = i02%ne12;
59
+ int i11 = i01;
60
+
61
+ global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
+ global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63
+ global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
66
 
ggml/src/ggml-opencl/kernels/softmax_f32.cl CHANGED
@@ -22,32 +22,45 @@
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max(
25
- global float * src0,
26
  ulong offset0,
27
- global float * src1,
28
  ulong offset1,
29
- global float * dst,
30
  ulong offsetd,
31
  int ne00,
32
- int ne01,
33
- int ne02,
 
 
 
 
 
 
 
 
 
34
  float scale,
35
  float max_bias,
36
  float m0,
37
  float m1,
38
  int n_head_log2
39
  ) {
40
- src0 = (global float*)((global char*)src0 + offset0);
41
- src1 = (global float*)((global char*)src1 + offset1);
42
- dst = (global float*)((global char*)dst + offsetd);
43
 
44
  int i03 = get_group_id(2);
45
  int i02 = get_group_id(1);
46
  int i01 = get_group_id(0);
47
 
48
- global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
49
- global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
50
- global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
 
 
 
51
 
52
  float slope = 1.0f;
53
 
 
22
  REQD_SUBGROUP_SIZE_64
23
  #endif
24
  kernel void kernel_soft_max(
25
+ global char * src0,
26
  ulong offset0,
27
+ global char * src1,
28
  ulong offset1,
29
+ global char * dst,
30
  ulong offsetd,
31
  int ne00,
32
+ ulong nb01,
33
+ ulong nb02,
34
+ ulong nb03,
35
+ int ne12,
36
+ int ne13,
37
+ ulong nb11,
38
+ ulong nb12,
39
+ ulong nb13,
40
+ ulong nb1,
41
+ ulong nb2,
42
+ ulong nb3,
43
  float scale,
44
  float max_bias,
45
  float m0,
46
  float m1,
47
  int n_head_log2
48
  ) {
49
+ src0 = src0 + offset0;
50
+ src1 = src1 + offset1;
51
+ dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
54
  int i02 = get_group_id(1);
55
  int i01 = get_group_id(0);
56
 
57
+ int i13 = i03%ne13;
58
+ int i12 = i02%ne12;
59
+ int i11 = i01;
60
+
61
+ global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
+ global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63
+ global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
66