ggml : remove ggml_flash_attn and ggml_flash_ff (llama/7463)
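Note: the two legacy fused ops are dropped; fused attention now goes through ggml_flash_attn_ext only. As a hedged migration sketch (not part of this commit), a call site of the removed op would move to the _ext variant roughly as below. The ggml_flash_attn_ext signature is assumed from headers of this period, (ctx, q, k, v, mask, scale, max_bias), consistent with the op params seen in this diff ("scale is on first pos, max_bias on second").

```c
#include <math.h>
#include "ggml.h"

// Hypothetical helper, illustrative only: replaces the removed
// ggml_flash_attn(ctx, q, k, v, /*masked=*/true) call with the _ext op,
// which takes an explicit KQ mask instead of a boolean flag.
static struct ggml_tensor * attn_compat(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,         // [n_embd, n_batch, n_head, 1]
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask) { // causal mask, replaces `masked`
    const float scale = 1.0f/sqrtf((float) q->ne[0]); // old op hard-coded 1/sqrt(D)
    return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, scale, /*max_bias=*/0.0f);
}
```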
ggml.c
CHANGED
@@ -2670,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",

-    "FLASH_ATTN",
     "FLASH_ATTN_EXT",
-    "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
@@ -2698,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2760,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",

-    "flash_attn(x)",
     "flash_attn_ext(x)",
-    "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
     "ssm_scan(x)",
@@ -2788,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6948,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }

-// ggml_flash_attn
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked) {
-    GGML_ASSERT(ggml_can_mul_mat(k, q));
-    // TODO: check if vT can be multiplied by (k*qT)
-
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
-    int32_t t = masked ? 1 : 0;
-    ggml_set_op_params(result, &t, sizeof(t));
-
-    result->op   = GGML_OP_FLASH_ATTN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-
-    return result;
-}
-
 // ggml_flash_attn_ext

 struct ggml_tensor * ggml_flash_attn_ext(
@@ -7039,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }

-// ggml_flash_ff
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1) {
-    GGML_ASSERT(ggml_can_mul_mat(b0, a));
-    // TODO: more checks
-
-    bool is_node = false;
-
-    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
-    result->op   = GGML_OP_FLASH_FF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b0;
-    result->src[2] = b1;
-    result->src[3] = c0;
-    result->src[4] = c1;
-
-    return result;
-}
-
 // ggml_flash_attn_back

 struct ggml_tensor * ggml_flash_attn_back(
@@ -7080,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
         struct ggml_tensor  * v,
         struct ggml_tensor  * d,
         bool                  masked) {
+    GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)

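Note: ggml_flash_attn_back is kept but effectively disabled. Any graph that still constructs GGML_OP_FLASH_ATTN_BACK now fails fast with the assertion above at tensor-creation time, instead of computing gradients against the removed forward op.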
@@ -15709,400 +15643,6 @@ static void ggml_compute_forward_argsort(
     }
 }

-// ggml_compute_forward_flash_attn
-
-static void ggml_compute_forward_flash_attn_f32(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-        for (int64_t ic = 0; ic < masked_begin; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2 % nek2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1, 0,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-        }
-
-        // scale
-        ggml_vec_scale_f32(masked_begin, S, scale);
-
-        for (int64_t i = masked_begin; i < M; i++) {
-            S[i] = -INFINITY;
-        }
-
-        // softmax
-        // exclude known -INF S[..] values from max and loop
-        // dont forget to set their SW values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(masked_begin, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(masked_begin, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < masked_begin; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        for (int64_t ic = 0; ic < nev1; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            // v indices
-            const int iv2 = iq2 % nev2;
-            const int iv3 = iq3;
-
-            ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                    (float *) ((char *) v->data   + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                    S, 0, 1);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_f16(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1, 0,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) v->data   + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                        S16, 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_attn_f16(params, masked, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_ext

 static void ggml_compute_forward_flash_attn_ext_f16(
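For reference, the removed scalar kernels implemented plain full-row attention: S = softmax(scale * K q) with scale = 1/sqrt(D), then the output row is V^T S. A minimal standalone restatement of that math (illustrative only, not ggml code: the real kernels worked on strided, multi-head, optionally masked tensors):

```c
#include <assert.h>
#include <math.h>
#include <stdlib.h>

// One query row q[D] against M key/value rows, row-major K[M][D], V[M][D].
static void naive_attn_row(int D, int M, const float *q,
                           const float *K, const float *V, float *out) {
    float *S = malloc(sizeof(float)*M);
    assert(S);
    const float scale = 1.0f/sqrtf((float) D);
    float max = -INFINITY;
    for (int i = 0; i < M; ++i) {                 // S = scale * K q
        float dot = 0.0f;
        for (int d = 0; d < D; ++d) dot += K[i*D + d]*q[d];
        S[i] = dot*scale;
        if (S[i] > max) max = S[i];
    }
    float sum = 0.0f;                             // max-subtracted softmax
    for (int i = 0; i < M; ++i) { S[i] = expf(S[i] - max); sum += S[i]; }
    for (int d = 0; d < D; ++d) {                 // out = V^T softmax(S)
        float acc = 0.0f;
        for (int i = 0; i < M; ++i) acc += V[i*D + d]*S[i];
        out[d] = acc/sum;
    }
    free(S);
}
```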
@@ -16336,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
     }
 }

-// ggml_compute_forward_flash_ff
-
-static void ggml_compute_forward_flash_ff_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a  = dst->src[0]; // F16
-    const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
-    const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
-    const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
-    const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb)
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = nea0;
-    //const int64_t N = nea1;
-    const int64_t M = neb01;
-
-    GGML_ASSERT(ne0 == nea0);
-    GGML_ASSERT(ne1 == nea1);
-    GGML_ASSERT(ne2 == nea2);
-
-    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb10 == sizeof(float));
-    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbc10 == sizeof(float));
-
-    GGML_ASSERT(neb00 == D);
-    GGML_ASSERT(neb01 == M);
-    GGML_ASSERT(neb10 == M);
-    GGML_ASSERT(neb11 == 1);
-
-    GGML_ASSERT(nec00 == M);
-    GGML_ASSERT(nec01 == D);
-    GGML_ASSERT(nec10 == D);
-    GGML_ASSERT(nec11 == 1);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by a rows using ggml_vec_dot_f32
-
-    // total rows in a
-    const int nr = nea1*nea2*nea3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // a indices
-        const int ia3 = ir/(nea2*nea1);
-        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
-        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
-        for (int64_t ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
-
-            // S indices
-            const int i1 = ib01;
-
-            ggml_vec_dot_f16(nea0,
-                    S + i1, 0,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
-                    (ggml_fp16_t *) ((char *)  a->data + (ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
-        }
-
-        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
-        //ggml_vec_gelu_f32(neb01, S, S);
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        ggml_vec_gelu_f16(neb01, S16, S16);
-
-        {
-            // dst indices
-            const int i1 = ia1;
-            const int i2 = ia2;
-            const int i3 = ia3;
-
-            for (int64_t ic = 0; ic < nec01; ++ic) {
-
-                ggml_vec_dot_f16(neb01,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) c0->data  + (ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
-                        S16, 0, 1);
-            }
-
-            ggml_vec_add_f32(nec01,
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) c1->data);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_ff(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * b0 = dst->src[1];
-
-    switch (b0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_ff_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false); // TODO
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_back

 static void ggml_compute_forward_flash_attn_back_f32(
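The removed fused feed-forward computed dst = c0 @ gelu(b0 @ a + b1) + c1 (see the two dot-product loops and the ggml_vec_gelu_f16 call above). As a hedged sketch, not something this commit adds, the same graph can be built from stock ggml ops; shapes follow the removed assertions (nea0 == D, neb01 == M, ...), and the bias adds rely on ggml_add broadcasting its second operand, consistent with its use elsewhere in ggml:

```c
// Hypothetical equivalent of the removed fused op, as a chain of standard ops.
static struct ggml_tensor * flash_ff_equiv(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,    // activations [D, N]
        struct ggml_tensor  * b0,   // fc weight   [D, M]
        struct ggml_tensor  * b1,   // fc bias     [M]
        struct ggml_tensor  * c0,   // proj weight [M, D]
        struct ggml_tensor  * c1) { // proj bias   [D]
    struct ggml_tensor * cur = ggml_mul_mat(ctx, b0, a); // [M, N]
    cur = ggml_add(ctx, cur, b1);                        // bias broadcasts over rows
    cur = ggml_gelu(ctx, cur);
    cur = ggml_mul_mat(ctx, c0, cur);                    // [D, N]
    return ggml_add(ctx, cur, c1);
}
```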
@@ -18065,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
-        case GGML_OP_FLASH_ATTN:
-            {
-                const int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                const bool masked = t != 0;
-                ggml_compute_forward_flash_attn(params, masked, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                ggml_compute_forward_flash_ff(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -19086,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -19140,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             zero_table);
                 }
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                GGML_ASSERT(false); // not supported
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 GGML_ASSERT(false); // not supported
@@ -19830,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 n_tasks = n_threads;
@@ -20235,40 +19595,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
-            case GGML_OP_FLASH_ATTN:
-                {
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D

                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                } break;
-            case GGML_OP_FLASH_FF:
-                {
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
                     const int64_t D = node->src[0]->ne[0];
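Note: with the legacy cases gone, attention work-buffer sizing in ggml_graph_plan reduces to the GGML_OP_FLASH_ATTN_EXT branch: three float vectors of head size per thread. For example, head size ne00 = 128 with n_tasks = 8 reserves 3*4*128*8 = 12288 bytes, independent of sequence length, whereas the removed GGML_OP_FLASH_ATTN branch scaled with the (unroll-padded) key length ne11.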
ggml.h
CHANGED
@@ -481,9 +481,7 @@ extern "C" {
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,

-        GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_ATTN_EXT,
-        GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
         GGML_OP_SSM_SCAN,
@@ -1761,13 +1759,6 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);

-    GGML_API struct ggml_tensor * ggml_flash_attn(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * q,
-            struct ggml_tensor  * k,
-            struct ggml_tensor  * v,
-            bool                  masked);
-
 #define GGML_KQ_MASK_PAD 32

     // q:    [n_embd, n_batch,     n_head, 1]
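With the boolean-flag API gone, masking is expressed as an explicit KQ mask tensor passed to ggml_flash_attn_ext. A hedged usage sketch (not part of this header; the F16 mask type and the GGML_KQ_MASK_PAD row padding are assumptions based on contemporaneous llama.cpp usage, and n_kv / n_tokens are hypothetical sizes):

```c
#include "ggml.h"

// Allocate a mask of n_kv columns, with the token dimension padded up to
// GGML_KQ_MASK_PAD as the define above suggests.
static struct ggml_tensor * make_kq_mask(struct ggml_context * ctx,
                                         int64_t n_kv, int64_t n_tokens) {
    return ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
            n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
}
```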
@@ -1788,6 +1779,7 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_prec        prec);

+    // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1796,14 +1788,6 @@ extern "C" {
             struct ggml_tensor  * d,
             bool                  masked);

-    GGML_API struct ggml_tensor * ggml_flash_ff(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b0,
-            struct ggml_tensor  * b1,
-            struct ggml_tensor  * c0,
-            struct ggml_tensor  * c1);
-
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
             struct ggml_tensor  * s,