katsu560 commited on
Commit
666b50a
·
1 Parent(s): add362d

Add AVX,AVX2 support for ggml_vec_scale_f32

Browse files
Files changed (1) hide show
  1. ggml.c +39 -1
ggml.c CHANGED
@@ -1118,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
1118
  #endif
1119
  }
1120
 
1121
- inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1122
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
1123
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1124
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
 
1118
  #endif
1119
  }
1120
 
1121
+ //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
1122
+ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
1123
+ #if defined(__AVX__) || defined(__AVX2__)
1124
+ // AVX 256-bit
1125
+ const int n32 = (n & ~31);
1126
+
1127
+ const __m256 v4 = _mm256_set1_ps(v);
1128
+
1129
+ __m256 y0, y1, y2, y3;
1130
+
1131
+ for (int i = 0; i < n32; i += 32) {
1132
+ y0 = _mm256_loadu_ps(y + i + 0);
1133
+ y1 = _mm256_loadu_ps(y + i + 8);
1134
+ y2 = _mm256_loadu_ps(y + i + 16);
1135
+ y3 = _mm256_loadu_ps(y + i + 24);
1136
+
1137
+ y0 = _mm256_mul_ps(y0, v4);
1138
+ y1 = _mm256_mul_ps(y1, v4);
1139
+ y2 = _mm256_mul_ps(y2, v4);
1140
+ y3 = _mm256_mul_ps(y3, v4);
1141
+
1142
+ _mm256_storeu_ps(y + i + 0, y0);
1143
+ _mm256_storeu_ps(y + i + 8, y1);
1144
+ _mm256_storeu_ps(y + i + 16, y2);
1145
+ _mm256_storeu_ps(y + i + 24, y3);
1146
+ }
1147
+
1148
+ // leftovers
1149
+ for (int i = n32; i < n; ++i) {
1150
+ y[i] *= v;
1151
+ }
1152
+ #else
1153
+ // scalar
1154
+ for (int i = 0; i < n; ++i) {
1155
+ y[i] *= v;
1156
+ }
1157
+ #endif
1158
+ }
1159
+
1160
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
1161
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1162
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }