ggerganov committed
Commit 4005bca · 1 Parent(s): 8737d46

ggml : remove ggml_flash_attn and ggml_flash_ff (llama/7463)

Files changed (2):
  1. ggml.c +4 -672
  2. ggml.h +1 -17
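
For code that still calls the removed ggml_flash_attn(), the replacement is ggml_flash_attn_ext(). The migration sketch below is not part of this commit; it assumes the ggml_flash_attn_ext() signature present at this revision (q/k/v, an optional KQ mask, a softmax scale and an ALiBi max_bias), and the helper name, kq_mask and n_embd_head are hypothetical caller-side values.

#include <math.h>
#include "ggml.h"

// Hypothetical helper: build the attention node with ggml_flash_attn_ext()
// where ggml_flash_attn(ctx, q, k, v, /*masked=*/true) was used before.
static struct ggml_tensor * build_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,    // may be NULL; padded to a multiple of GGML_KQ_MASK_PAD
        int                   n_embd_head) {
    struct ggml_tensor * kqv = ggml_flash_attn_ext(
            ctx, q, k, v, kq_mask,
            1.0f/sqrtf((float) n_embd_head), // softmax scale; the removed op computed 1/sqrt(D) internally
            0.0f);                           // max_bias (ALiBi); 0.0f disables it
    // optional: request F32 accumulation (stored as op param 2, see ggml_flash_attn_ext_set_prec in the diff)
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
    return kqv;
}
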
ggml.c CHANGED
@@ -2670,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2670
  "ARGSORT",
2671
  "LEAKY_RELU",
2672
 
2673
- "FLASH_ATTN",
2674
  "FLASH_ATTN_EXT",
2675
- "FLASH_FF",
2676
  "FLASH_ATTN_BACK",
2677
  "SSM_CONV",
2678
  "SSM_SCAN",
@@ -2698,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2698
  "CROSS_ENTROPY_LOSS_BACK",
2699
  };
2700
 
2701
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2702
 
2703
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2704
  "none",
@@ -2760,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2760
  "argsort(x)",
2761
  "leaky_relu(x)",
2762
 
2763
- "flash_attn(x)",
2764
  "flash_attn_ext(x)",
2765
- "flash_ff(x)",
2766
  "flash_attn_back(x)",
2767
  "ssm_conv(x)",
2768
  "ssm_scan(x)",
@@ -2788,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2788
  "cross_entropy_loss_back(x,y)",
2789
  };
2790
 
2791
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2792
 
2793
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2794
 
@@ -6948,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
6948
  return result;
6949
  }
6950
 
6951
- // ggml_flash_attn
6952
-
6953
- struct ggml_tensor * ggml_flash_attn(
6954
- struct ggml_context * ctx,
6955
- struct ggml_tensor * q,
6956
- struct ggml_tensor * k,
6957
- struct ggml_tensor * v,
6958
- bool masked) {
6959
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6960
- // TODO: check if vT can be multiplied by (k*qT)
6961
-
6962
- bool is_node = false;
6963
-
6964
- if (q->grad || k->grad || v->grad) {
6965
- is_node = true;
6966
- }
6967
-
6968
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6969
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6970
-
6971
- int32_t t = masked ? 1 : 0;
6972
- ggml_set_op_params(result, &t, sizeof(t));
6973
-
6974
- result->op = GGML_OP_FLASH_ATTN;
6975
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6976
- result->src[0] = q;
6977
- result->src[1] = k;
6978
- result->src[2] = v;
6979
-
6980
- return result;
6981
- }
6982
-
6983
  // ggml_flash_attn_ext
6984
 
6985
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -7039,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
7039
  ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
7040
  }
7041
 
7042
- // ggml_flash_ff
7043
-
7044
- struct ggml_tensor * ggml_flash_ff(
7045
- struct ggml_context * ctx,
7046
- struct ggml_tensor * a,
7047
- struct ggml_tensor * b0,
7048
- struct ggml_tensor * b1,
7049
- struct ggml_tensor * c0,
7050
- struct ggml_tensor * c1) {
7051
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
7052
- // TODO: more checks
7053
-
7054
- bool is_node = false;
7055
-
7056
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
7057
- is_node = true;
7058
- }
7059
-
7060
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7061
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
7062
-
7063
- result->op = GGML_OP_FLASH_FF;
7064
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7065
- result->src[0] = a;
7066
- result->src[1] = b0;
7067
- result->src[2] = b1;
7068
- result->src[3] = c0;
7069
- result->src[4] = c1;
7070
-
7071
- return result;
7072
- }
7073
-
7074
  // ggml_flash_attn_back
7075
 
7076
  struct ggml_tensor * ggml_flash_attn_back(
@@ -7080,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
7080
  struct ggml_tensor * v,
7081
  struct ggml_tensor * d,
7082
  bool masked) {
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
7083
  GGML_ASSERT(ggml_can_mul_mat(k, q));
7084
  // TODO: check if vT can be multiplied by (k*qT)
7085
 
@@ -15709,400 +15643,6 @@ static void ggml_compute_forward_argsort(
15709
  }
15710
  }
15711
 
15712
- // ggml_compute_forward_flash_attn
15713
-
15714
- static void ggml_compute_forward_flash_attn_f32(
15715
- const struct ggml_compute_params * params,
15716
- const bool masked,
15717
- struct ggml_tensor * dst) {
15718
-
15719
- const struct ggml_tensor * q = dst->src[0];
15720
- const struct ggml_tensor * k = dst->src[1];
15721
- const struct ggml_tensor * v = dst->src[2];
15722
-
15723
- int64_t t0 = ggml_perf_time_us();
15724
- UNUSED(t0);
15725
-
15726
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15727
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15728
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15729
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15730
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15731
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15732
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15733
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15734
-
15735
- const int ith = params->ith;
15736
- const int nth = params->nth;
15737
-
15738
- const int64_t D = neq0;
15739
- const int64_t N = neq1;
15740
- const int64_t P = nek1 - N;
15741
- const int64_t M = P + N;
15742
-
15743
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15744
-
15745
- GGML_ASSERT(ne0 == D);
15746
- GGML_ASSERT(ne1 == N);
15747
- GGML_ASSERT(P >= 0);
15748
-
15749
- GGML_ASSERT(nbq0 == sizeof(float));
15750
- GGML_ASSERT(nbk0 == sizeof(float));
15751
- GGML_ASSERT(nbv0 == sizeof(float));
15752
-
15753
- GGML_ASSERT(neq0 == D);
15754
- GGML_ASSERT(nek0 == D);
15755
- GGML_ASSERT(nev1 == D);
15756
-
15757
- GGML_ASSERT(neq1 == N);
15758
- GGML_ASSERT(nek1 == N + P);
15759
- GGML_ASSERT(nev1 == D);
15760
-
15761
- // dst cannot be transposed or permuted
15762
- GGML_ASSERT(nb0 == sizeof(float));
15763
- GGML_ASSERT(nb0 <= nb1);
15764
- GGML_ASSERT(nb1 <= nb2);
15765
- GGML_ASSERT(nb2 <= nb3);
15766
-
15767
- if (params->type == GGML_TASK_TYPE_INIT) {
15768
- return;
15769
- }
15770
-
15771
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15772
- return;
15773
- }
15774
-
15775
- // parallelize by q rows using ggml_vec_dot_f32
15776
-
15777
- // total rows in q
15778
- const int nr = neq1*neq2*neq3;
15779
-
15780
- // rows per thread
15781
- const int dr = (nr + nth - 1)/nth;
15782
-
15783
- // row range for this thread
15784
- const int ir0 = dr*ith;
15785
- const int ir1 = MIN(ir0 + dr, nr);
15786
-
15787
- const float scale = 1.0f/sqrtf(D);
15788
-
15789
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15790
-
15791
- for (int ir = ir0; ir < ir1; ++ir) {
15792
- // q indices
15793
- const int iq3 = ir/(neq2*neq1);
15794
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15795
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15796
-
15797
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15798
-
15799
- for (int i = M; i < Mup; ++i) {
15800
- S[i] = -INFINITY;
15801
- }
15802
-
15803
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15804
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15805
- // k indices
15806
- const int ik3 = iq3;
15807
- const int ik2 = iq2 % nek2;
15808
- const int ik1 = ic;
15809
-
15810
- // S indices
15811
- const int i1 = ik1;
15812
-
15813
- ggml_vec_dot_f32(neq0,
15814
- S + i1, 0,
15815
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15816
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15817
- }
15818
-
15819
- // scale
15820
- ggml_vec_scale_f32(masked_begin, S, scale);
15821
-
15822
- for (int64_t i = masked_begin; i < M; i++) {
15823
- S[i] = -INFINITY;
15824
- }
15825
-
15826
- // softmax
15827
- // exclude known -INF S[..] values from max and loop
15828
- // dont forget to set their SW values to zero
15829
- {
15830
- float max = -INFINITY;
15831
- ggml_vec_max_f32(masked_begin, &max, S);
15832
-
15833
- ggml_float sum = 0.0;
15834
- {
15835
- #ifdef GGML_SOFT_MAX_ACCELERATE
15836
- max = -max;
15837
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15838
- vvexpf(S, S, &Mup);
15839
- ggml_vec_sum_f32(Mup, &sum, S);
15840
- #else
15841
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15842
- #endif
15843
- }
15844
-
15845
- assert(sum > 0.0);
15846
-
15847
- sum = 1.0/sum;
15848
- ggml_vec_scale_f32(masked_begin, S, sum);
15849
-
15850
- #ifndef NDEBUG
15851
- for (int i = 0; i < masked_begin; ++i) {
15852
- assert(!isnan(S[i]));
15853
- assert(!isinf(S[i]));
15854
- }
15855
- #endif
15856
- }
15857
-
15858
- for (int64_t ic = 0; ic < nev1; ++ic) {
15859
- // dst indices
15860
- const int i1 = iq1;
15861
- const int i2 = iq2;
15862
- const int i3 = iq3;
15863
-
15864
- // v indices
15865
- const int iv2 = iq2 % nev2;
15866
- const int iv3 = iq3;
15867
-
15868
- ggml_vec_dot_f32(masked_begin,
15869
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15870
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15871
- S, 0, 1);
15872
- }
15873
- }
15874
- }
15875
-
15876
- static void ggml_compute_forward_flash_attn_f16(
15877
- const struct ggml_compute_params * params,
15878
- const bool masked,
15879
- struct ggml_tensor * dst) {
15880
-
15881
- const struct ggml_tensor * q = dst->src[0];
15882
- const struct ggml_tensor * k = dst->src[1];
15883
- const struct ggml_tensor * v = dst->src[2];
15884
-
15885
- int64_t t0 = ggml_perf_time_us();
15886
- UNUSED(t0);
15887
-
15888
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15889
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15890
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15891
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15892
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15893
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15894
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15895
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15896
-
15897
- const int ith = params->ith;
15898
- const int nth = params->nth;
15899
-
15900
- const int64_t D = neq0;
15901
- const int64_t N = neq1;
15902
- const int64_t P = nek1 - N;
15903
- const int64_t M = P + N;
15904
-
15905
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15906
-
15907
- GGML_ASSERT(ne0 == D);
15908
- GGML_ASSERT(ne1 == N);
15909
- GGML_ASSERT(P >= 0);
15910
-
15911
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15912
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15913
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15914
-
15915
- GGML_ASSERT(neq0 == D);
15916
- GGML_ASSERT(nek0 == D);
15917
- GGML_ASSERT(nev1 == D);
15918
-
15919
- GGML_ASSERT(neq1 == N);
15920
- GGML_ASSERT(nek1 == N + P);
15921
- GGML_ASSERT(nev1 == D);
15922
-
15923
- // dst cannot be transposed or permuted
15924
- GGML_ASSERT(nb0 == sizeof(float));
15925
- GGML_ASSERT(nb0 <= nb1);
15926
- GGML_ASSERT(nb1 <= nb2);
15927
- GGML_ASSERT(nb2 <= nb3);
15928
-
15929
- if (params->type == GGML_TASK_TYPE_INIT) {
15930
- return;
15931
- }
15932
-
15933
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15934
- return;
15935
- }
15936
-
15937
- // parallelize by q rows using ggml_vec_dot_f32
15938
-
15939
- // total rows in q
15940
- const int nr = neq1*neq2*neq3;
15941
-
15942
- // rows per thread
15943
- const int dr = (nr + nth - 1)/nth;
15944
-
15945
- // row range for this thread
15946
- const int ir0 = dr*ith;
15947
- const int ir1 = MIN(ir0 + dr, nr);
15948
-
15949
- const float scale = 1.0f/sqrtf(D);
15950
-
15951
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15952
-
15953
- for (int ir = ir0; ir < ir1; ++ir) {
15954
- // q indices
15955
- const int iq3 = ir/(neq2*neq1);
15956
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15957
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15958
-
15959
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15960
-
15961
- for (int i = M; i < Mup; ++i) {
15962
- S[i] = -INFINITY;
15963
- }
15964
-
15965
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15966
- for (int64_t ic = 0; ic < nek1; ++ic) {
15967
- // k indices
15968
- const int ik3 = iq3;
15969
- const int ik2 = iq2 % nek2;
15970
- const int ik1 = ic;
15971
-
15972
- // S indices
15973
- const int i1 = ik1;
15974
-
15975
- ggml_vec_dot_f16(neq0,
15976
- S + i1, 0,
15977
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15978
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15979
- }
15980
- } else {
15981
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15982
- // k indices
15983
- const int ik3 = iq3;
15984
- const int ik2 = iq2 % nek2;
15985
- const int ik1 = ic;
15986
-
15987
- // S indices
15988
- const int i1 = ik1;
15989
-
15990
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15991
- S + i1,
15992
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15993
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15994
- }
15995
- }
15996
-
15997
- // scale
15998
- ggml_vec_scale_f32(nek1, S, scale);
15999
-
16000
- if (masked) {
16001
- for (int64_t i = P; i < M; i++) {
16002
- if (i > P + iq1) {
16003
- S[i] = -INFINITY;
16004
- }
16005
- }
16006
- }
16007
-
16008
- // softmax
16009
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
16010
- // dont forget to set their S values to zero
16011
- {
16012
- float max = -INFINITY;
16013
- ggml_vec_max_f32(M, &max, S);
16014
-
16015
- ggml_float sum = 0.0;
16016
- {
16017
- #ifdef GGML_SOFT_MAX_ACCELERATE
16018
- max = -max;
16019
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
16020
- vvexpf(S, S, &Mup);
16021
- ggml_vec_sum_f32(Mup, &sum, S);
16022
- #else
16023
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
16024
- #endif
16025
- }
16026
-
16027
- assert(sum > 0.0);
16028
-
16029
- sum = 1.0/sum;
16030
- ggml_vec_scale_f32(M, S, sum);
16031
-
16032
- #ifndef NDEBUG
16033
- for (int i = 0; i < M; ++i) {
16034
- assert(!isnan(S[i]));
16035
- assert(!isinf(S[i]));
16036
- }
16037
- #endif
16038
- }
16039
-
16040
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
16041
-
16042
- for (int64_t i = 0; i < M; i++) {
16043
- S16[i] = GGML_FP32_TO_FP16(S[i]);
16044
- }
16045
-
16046
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
16047
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
16048
- for (int64_t ic = 0; ic < nev1; ++ic) {
16049
- // dst indices
16050
- const int i1 = iq1;
16051
- const int i2 = iq2;
16052
- const int i3 = iq3;
16053
-
16054
- // v indices
16055
- const int iv2 = iq2 % nev2;
16056
- const int iv3 = iq3;
16057
-
16058
- ggml_vec_dot_f16(nev0,
16059
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16060
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
16061
- S16, 0, 1);
16062
- }
16063
- } else {
16064
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
16065
- // dst indices
16066
- const int i1 = iq1;
16067
- const int i2 = iq2;
16068
- const int i3 = iq3;
16069
-
16070
- // v indices
16071
- const int iv2 = iq2 % nev2;
16072
- const int iv3 = iq3;
16073
-
16074
- ggml_vec_dot_f16_unroll(nev0, nbv1,
16075
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
16076
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
16077
- S16);
16078
- }
16079
- }
16080
- }
16081
- }
16082
-
16083
- static void ggml_compute_forward_flash_attn(
16084
- const struct ggml_compute_params * params,
16085
- const bool masked,
16086
- struct ggml_tensor * dst) {
16087
-
16088
- const struct ggml_tensor * q = dst->src[0];
16089
-
16090
- switch (q->type) {
16091
- case GGML_TYPE_F16:
16092
- {
16093
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
16094
- } break;
16095
- case GGML_TYPE_F32:
16096
- {
16097
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
16098
- } break;
16099
- default:
16100
- {
16101
- GGML_ASSERT(false);
16102
- } break;
16103
- }
16104
- }
16105
-
16106
  // ggml_compute_forward_flash_attn_ext
16107
 
16108
  static void ggml_compute_forward_flash_attn_ext_f16(
@@ -16336,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
16336
  }
16337
  }
16338
 
16339
- // ggml_compute_forward_flash_ff
16340
-
16341
- static void ggml_compute_forward_flash_ff_f16(
16342
- const struct ggml_compute_params * params,
16343
- struct ggml_tensor * dst) {
16344
-
16345
- const struct ggml_tensor * a = dst->src[0]; // F16
16346
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
16347
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
16348
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
16349
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
16350
-
16351
- int64_t t0 = ggml_perf_time_us();
16352
- UNUSED(t0);
16353
-
16354
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
16355
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
16356
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
16357
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
16358
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
16359
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
16360
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
16361
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
16362
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
16363
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
16364
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
16365
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
16366
-
16367
- const int ith = params->ith;
16368
- const int nth = params->nth;
16369
-
16370
- const int64_t D = nea0;
16371
- //const int64_t N = nea1;
16372
- const int64_t M = neb01;
16373
-
16374
- GGML_ASSERT(ne0 == nea0);
16375
- GGML_ASSERT(ne1 == nea1);
16376
- GGML_ASSERT(ne2 == nea2);
16377
-
16378
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
16379
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
16380
- GGML_ASSERT(nbb10 == sizeof(float));
16381
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
16382
- GGML_ASSERT(nbc10 == sizeof(float));
16383
-
16384
- GGML_ASSERT(neb00 == D);
16385
- GGML_ASSERT(neb01 == M);
16386
- GGML_ASSERT(neb10 == M);
16387
- GGML_ASSERT(neb11 == 1);
16388
-
16389
- GGML_ASSERT(nec00 == M);
16390
- GGML_ASSERT(nec01 == D);
16391
- GGML_ASSERT(nec10 == D);
16392
- GGML_ASSERT(nec11 == 1);
16393
-
16394
- // dst cannot be transposed or permuted
16395
- GGML_ASSERT(nb0 == sizeof(float));
16396
- GGML_ASSERT(nb0 <= nb1);
16397
- GGML_ASSERT(nb1 <= nb2);
16398
- GGML_ASSERT(nb2 <= nb3);
16399
-
16400
- if (params->type == GGML_TASK_TYPE_INIT) {
16401
- return;
16402
- }
16403
-
16404
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
16405
- return;
16406
- }
16407
-
16408
- // parallelize by a rows using ggml_vec_dot_f32
16409
-
16410
- // total rows in a
16411
- const int nr = nea1*nea2*nea3;
16412
-
16413
- // rows per thread
16414
- const int dr = (nr + nth - 1)/nth;
16415
-
16416
- // row range for this thread
16417
- const int ir0 = dr*ith;
16418
- const int ir1 = MIN(ir0 + dr, nr);
16419
-
16420
- for (int ir = ir0; ir < ir1; ++ir) {
16421
- // a indices
16422
- const int ia3 = ir/(nea2*nea1);
16423
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
16424
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
16425
-
16426
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
16427
-
16428
- for (int64_t ic = 0; ic < neb01; ++ic) {
16429
- // b0 indices
16430
- const int ib03 = ia3;
16431
- const int ib02 = ia2;
16432
- const int ib01 = ic;
16433
-
16434
- // S indices
16435
- const int i1 = ib01;
16436
-
16437
- ggml_vec_dot_f16(nea0,
16438
- S + i1, 0,
16439
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
16440
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
16441
- }
16442
-
16443
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
16444
- //ggml_vec_gelu_f32(neb01, S, S);
16445
-
16446
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
16447
-
16448
- for (int64_t i = 0; i < M; i++) {
16449
- S16[i] = GGML_FP32_TO_FP16(S[i]);
16450
- }
16451
-
16452
- ggml_vec_gelu_f16(neb01, S16, S16);
16453
-
16454
- {
16455
- // dst indices
16456
- const int i1 = ia1;
16457
- const int i2 = ia2;
16458
- const int i3 = ia3;
16459
-
16460
- for (int64_t ic = 0; ic < nec01; ++ic) {
16461
-
16462
- ggml_vec_dot_f16(neb01,
16463
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16464
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
16465
- S16, 0, 1);
16466
- }
16467
-
16468
- ggml_vec_add_f32(nec01,
16469
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16470
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16471
- (float *) c1->data);
16472
- }
16473
- }
16474
- }
16475
-
16476
- static void ggml_compute_forward_flash_ff(
16477
- const struct ggml_compute_params * params,
16478
- struct ggml_tensor * dst) {
16479
-
16480
- const struct ggml_tensor * b0 = dst->src[1];
16481
-
16482
- switch (b0->type) {
16483
- case GGML_TYPE_F16:
16484
- {
16485
- ggml_compute_forward_flash_ff_f16(params, dst);
16486
- } break;
16487
- case GGML_TYPE_F32:
16488
- {
16489
- GGML_ASSERT(false); // TODO
16490
- } break;
16491
- default:
16492
- {
16493
- GGML_ASSERT(false);
16494
- } break;
16495
- }
16496
- }
16497
-
16498
  // ggml_compute_forward_flash_attn_back
16499
 
16500
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -18065,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
18065
  {
18066
  ggml_compute_forward_leaky_relu(params, tensor);
18067
  } break;
18068
- case GGML_OP_FLASH_ATTN:
18069
- {
18070
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
18071
- GGML_ASSERT(t == 0 || t == 1);
18072
- const bool masked = t != 0;
18073
- ggml_compute_forward_flash_attn(params, masked, tensor);
18074
- } break;
18075
  case GGML_OP_FLASH_ATTN_EXT:
18076
  {
18077
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
18078
  } break;
18079
- case GGML_OP_FLASH_FF:
18080
- {
18081
- ggml_compute_forward_flash_ff(params, tensor);
18082
- } break;
18083
  case GGML_OP_FLASH_ATTN_BACK:
18084
  {
18085
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -19086,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
19086
  {
19087
  GGML_ASSERT(false); // TODO: not implemented
19088
  } break;
19089
- case GGML_OP_FLASH_ATTN:
19090
  case GGML_OP_FLASH_ATTN_EXT:
19091
  {
19092
  struct ggml_tensor * flash_grad = NULL;
@@ -19140,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
19140
  zero_table);
19141
  }
19142
  } break;
19143
- case GGML_OP_FLASH_FF:
19144
- {
19145
- GGML_ASSERT(false); // not supported
19146
- } break;
19147
  case GGML_OP_FLASH_ATTN_BACK:
19148
  {
19149
  GGML_ASSERT(false); // not supported
@@ -19830,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19830
  {
19831
  n_tasks = n_threads;
19832
  } break;
19833
- case GGML_OP_FLASH_ATTN:
19834
  case GGML_OP_FLASH_ATTN_EXT:
19835
  {
19836
  n_tasks = n_threads;
19837
  } break;
19838
- case GGML_OP_FLASH_FF:
19839
- {
19840
- n_tasks = n_threads;
19841
- } break;
19842
  case GGML_OP_FLASH_ATTN_BACK:
19843
  {
19844
  n_tasks = n_threads;
@@ -20235,40 +19595,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
20235
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
20236
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
20237
  } break;
20238
- case GGML_OP_FLASH_ATTN:
20239
- {
20240
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
20241
-
20242
- if (node->src[1]->type == GGML_TYPE_F32) {
20243
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
20244
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
20245
- } else if (node->src[1]->type == GGML_TYPE_F16) {
20246
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
20247
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
20248
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
20249
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
20250
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
20251
- }
20252
- } break;
20253
  case GGML_OP_FLASH_ATTN_EXT:
20254
  {
20255
  const int64_t ne00 = node->src[0]->ne[0]; // D
20256
 
20257
  cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
20258
  } break;
20259
- case GGML_OP_FLASH_FF:
20260
- {
20261
- if (node->src[1]->type == GGML_TYPE_F32) {
20262
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
20263
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
20264
- } else if (node->src[1]->type == GGML_TYPE_F16) {
20265
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
20266
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
20267
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
20268
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
20269
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
20270
- }
20271
- } break;
20272
  case GGML_OP_FLASH_ATTN_BACK:
20273
  {
20274
  const int64_t D = node->src[0]->ne[0];
 
ggml.h CHANGED
@@ -481,9 +481,7 @@ extern "C" {
481
  GGML_OP_ARGSORT,
482
  GGML_OP_LEAKY_RELU,
483
 
484
- GGML_OP_FLASH_ATTN,
485
  GGML_OP_FLASH_ATTN_EXT,
486
- GGML_OP_FLASH_FF,
487
  GGML_OP_FLASH_ATTN_BACK,
488
  GGML_OP_SSM_CONV,
489
  GGML_OP_SSM_SCAN,
@@ -1761,13 +1759,6 @@ extern "C" {
1761
  struct ggml_tensor * a,
1762
  int k);
1763
 
1764
- GGML_API struct ggml_tensor * ggml_flash_attn(
1765
- struct ggml_context * ctx,
1766
- struct ggml_tensor * q,
1767
- struct ggml_tensor * k,
1768
- struct ggml_tensor * v,
1769
- bool masked);
1770
-
1771
  #define GGML_KQ_MASK_PAD 32
1772
 
1773
  // q: [n_embd, n_batch, n_head, 1]
@@ -1788,6 +1779,7 @@ extern "C" {
1788
  struct ggml_tensor * a,
1789
  enum ggml_prec prec);
1790
 
+ // TODO: needs to be adapted to ggml_flash_attn_ext
1791
  GGML_API struct ggml_tensor * ggml_flash_attn_back(
1792
  struct ggml_context * ctx,
1793
  struct ggml_tensor * q,
@@ -1796,14 +1788,6 @@ extern "C" {
1796
  struct ggml_tensor * d,
1797
  bool masked);
1798
 
1799
- GGML_API struct ggml_tensor * ggml_flash_ff(
1800
- struct ggml_context * ctx,
1801
- struct ggml_tensor * a,
1802
- struct ggml_tensor * b0,
1803
- struct ggml_tensor * b1,
1804
- struct ggml_tensor * c0,
1805
- struct ggml_tensor * c1);
1806
-
1807
  GGML_API struct ggml_tensor * ggml_ssm_conv(
1808
  struct ggml_context * ctx,
1809
  struct ggml_tensor * s,
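
Note: the removed ggml_flash_ff(a, b0, b1, c0, c1) fused a small GELU MLP, computing gelu(b0*a + b1) and then projecting with c0 and adding c1 (the removed comments label b0/b1 as fc_w/fc_b and c0/c1 as proj_w/proj_b). A minimal equivalent built from the remaining ops, not part of this commit and assuming the usual ggml bias broadcasting, would be:

// sketch: graph equivalent of the removed ggml_flash_ff(ctx, a, b0, b1, c0, c1)
struct ggml_tensor * cur = ggml_mul_mat(ctx, b0, a); // fc_w  * a
cur = ggml_add    (ctx, cur, b1);                    // + fc_b
cur = ggml_gelu   (ctx, cur);                        // GELU activation
cur = ggml_mul_mat(ctx, c0, cur);                    // proj_w * cur
cur = ggml_add    (ctx, cur, c1);                    // + proj_b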
 