Justine Tunney committed on
Commit 81ec961 · 1 Parent(s): 24e883a

ggml : introduce bfloat16 support (llama/6412)


* Introduce bfloat16 support

Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as
their canonical floating point format.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───┐
    0b0000000000000000 brain16

This encoding has the same number of exponent bits as float32. That
makes conversion relatively straightforward, even in the absence of
hardware support. For example, converting brain16 to binary32 means
simply shifting 16 bits to the left.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───────────────────┐
    0b00000000000000000000000000000000 IEEE binary32
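
As a plain-C illustration of that shift (a minimal sketch, not part of this commit; the helper name here is made up, the real implementation is the ggml_compute_bf16_to_fp32() added to ggml-impl.h below):

    #include <stdint.h>
    #include <string.h>

    // Widen a bf16 bit pattern into a float by moving it to the top half of a
    // 32-bit word; sign and exponent line up, the extra mantissa bits become zero.
    static float bf16_bits_to_float(uint16_t bits) {
        uint32_t u = (uint32_t)bits << 16;
        float f;
        memcpy(&f, &u, sizeof f); // type-pun via memcpy to avoid aliasing issues
        return f;
    }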

The issue is that converting bf16 to fp16 can result in information loss.
Only 13% of bf16 numbers can be represented exactly in fp16, which in
practice still covers 99.71% of Mistral 7b v0.2's weights; however, there
is currently no way other than fp32 to represent the remaining values.

      ┌sign
      │
      │  ┌exponent
      │  │
      │  │    ┌mantissa
      │  │    │
      │┌─┴─┐┌─┴──────┐
    0b0000000000000000 IEEE binary16

This change fixes that by adding a bf16 data type to GGML. Support for
CPU inference has been implemented, along with optimizations for the
AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b v0.2
improves by roughly 0.0024 to 0.0046 compared to using fp16.
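
For example, the new row converters added in ggml.c below can round-trip a buffer of weights. This is only an illustrative sketch; it assumes ggml_bf16_t and the two row functions are exported through the ggml.h changes in this commit:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // 1e-40f is subnormal in fp32 and gets flushed to zero when encoding to bf16
        float src[4]  = {1.0f, -2.5f, 3.14159265f, 1e-40f};
        ggml_bf16_t bf[4];
        float back[4];

        ggml_fp32_to_bf16_row(src, bf, 4);   // fp32 -> bf16, round to nearest even
        ggml_bf16_to_fp32_row(bf, back, 4);  // bf16 -> fp32, exact (just a shift)

        for (int i = 0; i < 4; i++) {
            printf("%g -> %g\n", src[i], back[i]);
        }
        return 0;
    }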

* Remove GGML code that's not needed

* Minimize the GGML API surface area for BF16

* Remove bf16 luts

* Make the GGML header look nicer

* Fix documentation

* Apply ggerganov's fixes for test-backend-ops

* Add BF16 code for new ggml_validate_row_data() function

Files changed (5)
  1. ggml-impl.h +77 -0
  2. ggml-metal.m +1 -1
  3. ggml-quants.c +18 -0
  4. ggml.c +1089 -88
  5. ggml.h +15 -7
ggml-impl.h CHANGED
@@ -17,6 +17,83 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    if (!(u.i & 0x7f800000)) { /* subnormal */
+        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
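
For illustration only (not part of the diff): a standalone copy of the rounding rule above, showing the nearest-even tie behavior plus the NaN and subnormal handling. The helper name is made up; the real entry point is the GGML_FP32_TO_BF16 macro.

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    // Same bit manipulation as ggml_compute_fp32_to_bf16, written standalone.
    static uint16_t fp32_to_bf16_bits(float s) {
        uint32_t i;
        memcpy(&i, &s, sizeof i);
        if ((i & 0x7fffffff) > 0x7f800000) {          // NaN: keep payload, force quiet bit
            return (uint16_t)((i >> 16) | 64);
        }
        if (!(i & 0x7f800000)) {                      // subnormal: flush to signed zero
            return (uint16_t)((i & 0x80000000) >> 16);
        }
        return (uint16_t)((i + (0x7fff + ((i >> 16) & 1))) >> 16); // round to nearest, ties to even
    }

    int main(void) {
        printf("%04x\n", fp32_to_bf16_bits(1.0f));        // 3f80
        printf("%04x\n", fp32_to_bf16_bits(1.00390625f)); // exactly halfway -> stays 3f80 (even)
        printf("%04x\n", fp32_to_bf16_bits(1e-40f));      // subnormal -> 0000
        return 0;
    }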
ggml-metal.m CHANGED
@@ -806,7 +806,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->ne[3] == 1;
+                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
             }
         default:
             return false;
ggml-quants.c CHANGED
@@ -12456,6 +12456,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
     const size_t nb = nbytes/ggml_type_size(type);
 
     switch (type) {
+        case GGML_TYPE_BF16:
+            {
+                int nans = 0;
+                int infs = 0;
+                const unsigned short * f = (const unsigned short *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    nans += (f[i] & 0x7fff) > 0x7f80;
+                    infs += (f[i] & 0x7fff) == 0x7f80;
+                }
+                if (nans) {
+                    fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
+                    return false;
+                }
+                if (infs) {
+                    fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
+                    return false;
+                }
+            } break;
         case GGML_TYPE_F16:
             {
                 const ggml_fp16_t * f = (const ggml_fp16_t *) data;
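
For illustration only (not part of the diff): the check above works purely on bf16 bit patterns, where an all-ones exponent field (0x7f80 after masking off the sign) means infinity if the mantissa is zero and NaN otherwise. A standalone sketch with a made-up helper name:

    #include <stdint.h>
    #include <stdio.h>

    // Classify a raw bf16 value by its bits, mirroring the validation loop above.
    static const char * classify_bf16(uint16_t bits) {
        uint16_t mag = bits & 0x7fff;          // drop the sign bit
        if (mag >  0x7f80) return "nan";       // exponent all ones, mantissa != 0
        if (mag == 0x7f80) return "inf";       // exponent all ones, mantissa == 0
        return "finite";
    }

    int main(void) {
        printf("%s\n", classify_bf16(0x7f80)); // +inf
        printf("%s\n", classify_bf16(0xffc0)); // quiet negative NaN
        printf("%s\n", classify_bf16(0x3f80)); // 1.0
        return 0;
    }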
ggml.c CHANGED
@@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
322
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
  float ggml_table_f32_f16[1 << 16];
324
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
326
  switch (status) {
327
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
  return "GGML status: unknown";
334
  }
335
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
 
339
  return GGML_FP16_TO_FP32(x);
340
  }
341
 
342
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
 
343
  return GGML_FP32_TO_FP16(x);
344
  }
345
 
346
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
  for (int64_t i = 0; i < n; i++) {
348
  y[i] = GGML_FP16_TO_FP32(x[i]);
@@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
  }
369
  }
370
 
371
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
  }
@@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
 
504
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
 
506
 
507
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
  [GGML_TYPE_I8] = {
@@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
  .type_size = sizeof(block_q8_K),
846
  .is_quantized = true,
847
  .from_float = quantize_row_q8_K,
848
  }
849
  };
850
 
@@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
 
1481
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
 
 
 
1483
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
  UNUSED(by);
1499
  UNUSED(bs);
1500
 
1501
- #ifdef GGML_SIMD
1502
  float sumf = 0.0f;
1503
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
 
@@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
  *s = sumf;
1535
  }
1536
 
1537
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
  assert(nrc == 1);
1539
  UNUSED(nrc);
@@ -1968,6 +2100,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1968
  *s = sum;
1969
  }
1970
 
 
1971
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1972
  #ifndef GGML_USE_ACCELERATE
1973
  float max = -INFINITY;
@@ -2379,7 +2519,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2379
  // figure out which node we're on
2380
  uint current_cpu;
2381
  int getcpu_ret = 0;
2382
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2383
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2384
  #else
2385
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2590,6 +2730,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2590
  switch (ftype) {
2591
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2592
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
 
2593
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2594
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2595
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2731,15 +2872,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2731
  {
2732
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2733
 
2734
- ggml_fp16_t ii;
2735
  for (int i = 0; i < (1 << 16); ++i) {
2736
- uint16_t ui = i;
2737
- memcpy(&ii, &ui, sizeof(ii));
2738
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
 
 
2739
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2740
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2741
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2742
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2743
  }
2744
 
2745
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3203,6 +3345,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3203
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3204
  }
3205
  } break;
3206
  case GGML_TYPE_F32:
3207
  {
3208
  assert(tensor->nb[0] == sizeof(float));
@@ -3255,6 +3404,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3255
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3256
  }
3257
  } break;
3258
  case GGML_TYPE_F32:
3259
  {
3260
  assert(tensor->nb[0] == sizeof(float));
@@ -3322,6 +3478,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3322
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3323
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3324
  }
 
 
 
 
 
3325
  case GGML_TYPE_F32:
3326
  {
3327
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3364,6 +3525,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3364
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3365
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3366
  } break;
 
 
 
 
 
3367
  case GGML_TYPE_F32:
3368
  {
3369
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3387,6 +3553,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3387
  return ((int32_t *) data)[0];
3388
  case GGML_TYPE_F16:
3389
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
 
 
3390
  case GGML_TYPE_F32:
3391
  return ((float *) data)[0];
3392
  default:
@@ -3415,6 +3583,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3415
  {
3416
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3417
  } break;
 
 
 
 
3418
  case GGML_TYPE_F32:
3419
  {
3420
  ((float *)(data))[0] = value;
@@ -3453,6 +3625,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3453
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3454
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3455
  }
 
 
 
 
 
3456
  case GGML_TYPE_F32:
3457
  {
3458
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3495,6 +3672,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3495
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3496
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3497
  } break;
 
 
 
 
 
3498
  case GGML_TYPE_F32:
3499
  {
3500
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3518,6 +3700,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3518
  return ((int32_t *) data)[0];
3519
  case GGML_TYPE_F16:
3520
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
 
 
3521
  case GGML_TYPE_F32:
3522
  return ((float *) data)[0];
3523
  default:
@@ -3546,6 +3730,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3546
  {
3547
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3548
  } break;
 
 
 
 
3549
  case GGML_TYPE_F32:
3550
  {
3551
  ((float *)(data))[0] = value;
@@ -3740,7 +3928,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3740
  // TODO: support less-strict constraint
3741
  // GGML_ASSERT(ggml_can_repeat(b, a));
3742
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3743
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
 
 
 
 
3744
 
3745
  bool is_node = false;
3746
 
@@ -7231,8 +7423,8 @@ static void ggml_compute_forward_dup_same_cont(
7231
  ((char *) src0->data + ie0*nb00),
7232
  (ie1 - ie0) * ggml_type_size(src0->type));
7233
  }
7234
-
7235
  }
 
7236
  static void ggml_compute_forward_dup_f16(
7237
  const struct ggml_compute_params * params,
7238
  struct ggml_tensor * dst) {
@@ -7506,7 +7698,7 @@ static void ggml_compute_forward_dup_f16(
7506
  }
7507
  }
7508
 
7509
- static void ggml_compute_forward_dup_f32(
7510
  const struct ggml_compute_params * params,
7511
  struct ggml_tensor * dst) {
7512
 
@@ -7554,10 +7746,11 @@ static void ggml_compute_forward_dup_f32(
7554
  return;
7555
  }
7556
 
 
 
7557
  if (ggml_is_contiguous(dst)) {
7558
- // TODO: simplify
7559
- if (nb00 == sizeof(float)) {
7560
- if (dst->type == GGML_TYPE_F32) {
7561
  size_t id = 0;
7562
  const size_t rs = ne00 * nb00;
7563
  char * dst_ptr = (char *) dst->data;
@@ -7573,8 +7766,43 @@ static void ggml_compute_forward_dup_f32(
7573
  id += rs * (ne01 - ir1);
7574
  }
7575
7576
  } else if (type_traits[dst->type].from_float) {
7577
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
 
7578
 
7579
  size_t id = 0;
7580
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -7584,8 +7812,13 @@ static void ggml_compute_forward_dup_f32(
7584
  for (int i02 = 0; i02 < ne02; i02++) {
7585
  id += rs * ir0;
7586
  for (int i01 = ir0; i01 < ir1; i01++) {
7587
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7588
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
 
 
 
 
 
7589
  id += rs;
7590
  }
7591
  id += rs * (ne01 - ir1);
@@ -7606,7 +7839,25 @@ static void ggml_compute_forward_dup_f32(
7606
  id += ne00 * ir0;
7607
  for (int i01 = ir0; i01 < ir1; i01++) {
7608
  for (int i00 = 0; i00 < ne00; i00++) {
7609
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
7611
  dst_ptr[id] = *src0_ptr;
7612
  id++;
@@ -7624,9 +7875,9 @@ static void ggml_compute_forward_dup_f32(
7624
  id += ne00 * ir0;
7625
  for (int i01 = ir0; i01 < ir1; i01++) {
7626
  for (int i00 = 0; i00 < ne00; i00++) {
7627
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7628
 
7629
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7630
  id++;
7631
  }
7632
  }
@@ -7637,18 +7888,16 @@ static void ggml_compute_forward_dup_f32(
7637
  GGML_ASSERT(false); // TODO: implement
7638
  }
7639
  }
7640
-
7641
  return;
7642
  }
7643
 
7644
  // dst counters
7645
-
7646
  int64_t i10 = 0;
7647
  int64_t i11 = 0;
7648
  int64_t i12 = 0;
7649
  int64_t i13 = 0;
7650
 
7651
- if (dst->type == GGML_TYPE_F32) {
7652
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7653
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7654
  i10 += ne00 * ir0;
@@ -7669,7 +7918,59 @@ static void ggml_compute_forward_dup_f32(
7669
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7670
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7671
 
7672
- memcpy(dst_ptr, src0_ptr, sizeof(float));
7673
 
7674
  if (++i10 == ne0) {
7675
  i10 = 0;
@@ -7700,7 +8001,7 @@ static void ggml_compute_forward_dup_f32(
7700
  }
7701
  }
7702
  }
7703
- } else if (dst->type == GGML_TYPE_F16) {
7704
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7705
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7706
  i10 += ne00 * ir0;
@@ -7721,7 +8022,7 @@ static void ggml_compute_forward_dup_f32(
7721
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7722
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7723
 
7724
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
7725
 
7726
  if (++i10 == ne0) {
7727
  i10 = 0;
@@ -7757,31 +8058,27 @@ static void ggml_compute_forward_dup_f32(
7757
  }
7758
  }
7759
 
7760
- // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
7761
- static void ggml_compute_forward_dup_bytes(
7762
  const struct ggml_compute_params * params,
7763
  struct ggml_tensor * dst) {
7764
 
7765
  const struct ggml_tensor * src0 = dst->src[0];
7766
 
7767
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
7768
- GGML_ASSERT(src0->type == dst->type);
7769
 
7770
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7771
  return;
7772
  }
7773
 
7774
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
7775
- ggml_compute_forward_dup_same_cont(params, dst);
7776
- return;
7777
- }
7778
-
7779
- GGML_TENSOR_UNARY_OP_LOCALS;
7780
 
7781
- const size_t type_size = ggml_type_size(src0->type);
7782
  const int ith = params->ith; // thread index
7783
  const int nth = params->nth; // number of threads
7784
 
 
 
 
 
7785
 
7786
  // parallelize by rows
7787
  const int nr = ne01;
@@ -7793,9 +8090,9 @@ static void ggml_compute_forward_dup_bytes(
7793
 
7794
  if (src0->type == dst->type &&
7795
  ne00 == ne0 &&
7796
- nb00 == type_size && nb0 == type_size) {
7797
  // copy by rows
7798
- const size_t rs = ne00 * type_size;
7799
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7800
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7801
  for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -7810,41 +8107,366 @@ static void ggml_compute_forward_dup_bytes(
7810
  }
7811
 
7812
  if (ggml_is_contiguous(dst)) {
7813
- size_t id = 0;
7814
- char * dst_ptr = (char *) dst->data;
7815
- const size_t rs = ne00 * type_size;
7816
-
7817
- if (nb00 == type_size) {
7818
- // src0 is contigous on first dimension, copy by rows
7819
- for (int64_t i03 = 0; i03 < ne03; i03++) {
7820
- for (int64_t i02 = 0; i02 < ne02; i02++) {
7821
- id += rs * ir0;
7822
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
7823
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
7824
- memcpy(dst_ptr + id, src0_ptr, rs);
7825
- id += rs;
7826
- }
7827
- id += rs * (ne01 - ir1);
7828
- }
7829
- }
7830
- } else {
7831
- //printf("%s: this is not optimal - fix me\n", __func__);
7832
-
7833
- for (int64_t i03 = 0; i03 < ne03; i03++) {
7834
- for (int64_t i02 = 0; i02 < ne02; i02++) {
7835
- id += rs * ir0;
7836
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
7837
- for (int64_t i00 = 0; i00 < ne00; i00++) {
7838
- const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
7839
- memcpy(dst_ptr + id, src0_ptr, type_size);
7840
 
7841
- id += type_size;
 
 
 
 
 
 
7842
  }
 
7843
  }
7844
- id += rs * (ne01 - ir1);
7845
  }
7846
- }
7847
- }
7848
 
7849
  return;
7850
  }
@@ -7925,6 +8547,10 @@ static void ggml_compute_forward_dup(
7925
  {
7926
  ggml_compute_forward_dup_f16(params, dst);
7927
  } break;
 
 
 
 
7928
  case GGML_TYPE_F32:
7929
  {
7930
  ggml_compute_forward_dup_f32(params, dst);
@@ -8018,17 +8644,96 @@ static void ggml_compute_forward_add_f32(
8018
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8019
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8020
 
8021
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
8022
- const int64_t i10 = i0 % ne10;
8023
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
8024
 
8025
- dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
 
 
8026
  }
8027
  }
8028
  }
 
 
 
 
8029
  }
8030
 
8031
- static void ggml_compute_forward_add_f16_f32(
8032
  const struct ggml_compute_params * params,
8033
  struct ggml_tensor * dst) {
8034
 
@@ -8048,18 +8753,18 @@ static void ggml_compute_forward_add_f16_f32(
8048
 
8049
  GGML_TENSOR_BINARY_OP_LOCALS
8050
 
8051
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
8052
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8053
 
8054
  if (dst->type == GGML_TYPE_F32) {
8055
  GGML_ASSERT( nb0 == sizeof(float));
8056
  }
8057
  else {
8058
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
8059
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
8060
  }
8061
 
8062
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
8063
 
8064
  // rows per thread
8065
  const int dr = (nr + nth - 1)/nth;
@@ -8069,19 +8774,19 @@ static void ggml_compute_forward_add_f16_f32(
8069
  const int ir1 = MIN(ir0 + dr, nr);
8070
 
8071
  if (nb10 == sizeof(float)) {
8072
- if (dst->type == GGML_TYPE_F16) {
8073
  for (int ir = ir0; ir < ir1; ++ir) {
8074
  // src0, src1 and dst are same shape => same indices
8075
  const int i3 = ir/(ne2*ne1);
8076
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8077
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8078
 
8079
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8080
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8081
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8082
 
8083
  for (int i = 0; i < ne0; i++) {
8084
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8085
  }
8086
  }
8087
  } else {
@@ -8092,11 +8797,11 @@ static void ggml_compute_forward_add_f16_f32(
8092
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8093
 
8094
  float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8095
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8096
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8097
 
8098
  for (int i = 0; i < ne0; i++) {
8099
- dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8100
  }
8101
  }
8102
  }
@@ -8163,6 +8868,62 @@ static void ggml_compute_forward_add_f16_f16(
8163
  }
8164
  }
8165
 
8166
  static void ggml_compute_forward_add_q_f32(
8167
  const struct ggml_compute_params * params,
8168
  struct ggml_tensor * dst) {
@@ -8272,6 +9033,18 @@ static void ggml_compute_forward_add(
8272
  GGML_ASSERT(false);
8273
  }
8274
  } break;
8275
  case GGML_TYPE_Q4_0:
8276
  case GGML_TYPE_Q4_1:
8277
  case GGML_TYPE_Q5_0:
@@ -8530,6 +9303,110 @@ static void ggml_compute_forward_add1_q_f32(
8530
  }
8531
  }
8532
 
8533
  static void ggml_compute_forward_add1(
8534
  const struct ggml_compute_params * params,
8535
  struct ggml_tensor * dst) {
@@ -8554,6 +9431,18 @@ static void ggml_compute_forward_add1(
8554
  GGML_ASSERT(false);
8555
  }
8556
  } break;
8557
  case GGML_TYPE_Q4_0:
8558
  case GGML_TYPE_Q4_1:
8559
  case GGML_TYPE_Q5_0:
@@ -8682,6 +9571,7 @@ static void ggml_compute_forward_acc(
8682
  ggml_compute_forward_acc_f32(params, dst);
8683
  } break;
8684
  case GGML_TYPE_F16:
 
8685
  case GGML_TYPE_Q4_0:
8686
  case GGML_TYPE_Q4_1:
8687
  case GGML_TYPE_Q5_0:
@@ -9203,6 +10093,40 @@ static void ggml_compute_forward_sum_f16(
9203
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9204
  }
9205
 
9206
  static void ggml_compute_forward_sum(
9207
  const struct ggml_compute_params * params,
9208
  struct ggml_tensor * dst) {
@@ -9218,6 +10142,10 @@ static void ggml_compute_forward_sum(
9218
  {
9219
  ggml_compute_forward_sum_f16(params, dst);
9220
  } break;
 
 
 
 
9221
  default:
9222
  {
9223
  GGML_ASSERT(false);
@@ -9492,6 +10420,7 @@ static void ggml_compute_forward_repeat(
9492
 
9493
  switch (src0->type) {
9494
  case GGML_TYPE_F16:
 
9495
  case GGML_TYPE_I16:
9496
  {
9497
  ggml_compute_forward_repeat_f16(params, dst);
@@ -11855,6 +12784,7 @@ static void ggml_compute_forward_set(
11855
  ggml_compute_forward_set_f32(params, dst);
11856
  } break;
11857
  case GGML_TYPE_F16:
 
11858
  case GGML_TYPE_Q4_0:
11859
  case GGML_TYPE_Q4_1:
11860
  case GGML_TYPE_Q5_0:
@@ -12029,6 +12959,49 @@ static void ggml_compute_forward_get_rows_f16(
12029
  }
12030
  }
12031
 
 
12032
  static void ggml_compute_forward_get_rows_f32(
12033
  const struct ggml_compute_params * params,
12034
  struct ggml_tensor * dst) {
@@ -12106,6 +13079,10 @@ static void ggml_compute_forward_get_rows(
12106
  {
12107
  ggml_compute_forward_get_rows_f16(params, dst);
12108
  } break;
 
 
 
 
12109
  case GGML_TYPE_F32:
12110
  case GGML_TYPE_I32:
12111
  {
@@ -12801,6 +13778,7 @@ static void ggml_compute_forward_alibi(
12801
  {
12802
  ggml_compute_forward_alibi_f32(params, dst);
12803
  } break;
 
12804
  case GGML_TYPE_Q4_0:
12805
  case GGML_TYPE_Q4_1:
12806
  case GGML_TYPE_Q5_0:
@@ -12890,6 +13868,7 @@ static void ggml_compute_forward_clamp(
12890
  ggml_compute_forward_clamp_f32(params, dst);
12891
  } break;
12892
  case GGML_TYPE_F16:
 
12893
  case GGML_TYPE_Q4_0:
12894
  case GGML_TYPE_Q4_1:
12895
  case GGML_TYPE_Q5_0:
@@ -15987,6 +16966,7 @@ static void ggml_compute_forward_get_rel_pos(
15987
 
15988
  switch (src0->type) {
15989
  case GGML_TYPE_F16:
 
15990
  {
15991
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15992
  } break;
@@ -18856,7 +19836,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18856
  case GGML_OP_CPY:
18857
  case GGML_OP_DUP:
18858
  {
18859
- if (ggml_is_quantized(node->type)) {
 
 
 
18860
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18861
  }
18862
  } break;
@@ -18935,7 +19918,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18935
  const int64_t ne10 = node->src[1]->ne[0]; // L
18936
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18937
 
18938
- if (node->src[0]->type == GGML_TYPE_F16 &&
 
18939
  node->src[1]->type == GGML_TYPE_F32) {
18940
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18941
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18971,6 +19955,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18971
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18972
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18973
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
 
 
 
18974
  }
18975
  } break;
18976
  case GGML_OP_FLASH_ATTN_EXT:
@@ -18987,6 +19974,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18987
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18988
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18989
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
 
 
 
18990
  }
18991
  } break;
18992
  case GGML_OP_FLASH_ATTN_BACK:
@@ -19000,6 +19990,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19000
  } else if (node->src[1]->type == GGML_TYPE_F16) {
19001
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19002
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
 
 
 
19003
  }
19004
  } break;
19005
 
@@ -19776,7 +20769,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19776
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19777
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19778
  }
19779
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
 
 
19780
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19781
  }
19782
  else {
@@ -20834,6 +21829,12 @@ size_t ggml_quantize_chunk(
20834
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20835
  result = n * elemsize;
20836
  } break;
 
 
 
 
 
 
20837
  case GGML_TYPE_F32:
20838
  {
20839
  size_t elemsize = sizeof(float);
 
322
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
  float ggml_table_f32_f16[1 << 16];
324
 
325
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
  switch (status) {
327
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
 
333
  return "GGML status: unknown";
334
  }
335
 
 
 
336
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
337
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
338
  return GGML_FP16_TO_FP32(x);
339
  }
340
 
341
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
342
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
  return GGML_FP32_TO_FP16(x);
344
  }
345
 
346
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
347
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
348
+ return GGML_BF16_TO_FP32(x); // it just left shifts
349
+ }
350
+
351
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
352
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
353
+ return GGML_FP32_TO_BF16(x);
354
+ }
355
+
356
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
357
  for (int64_t i = 0; i < n; i++) {
358
  y[i] = GGML_FP16_TO_FP32(x[i]);
 
378
  }
379
  }
380
 
381
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
382
+ int64_t i = 0;
383
+ #if defined(__AVX512F__)
384
+ for (; i + 16 <= n; i += 16) {
385
+ _mm512_storeu_ps(y + i,
386
+ _mm512_castsi512_ps(
387
+ _mm512_slli_epi32(
388
+ _mm512_cvtepu16_epi32(
389
+ _mm256_loadu_si256(
390
+ (const __m256i *)(x + i))),
391
+ 16)));
392
+ }
393
+ #elif defined(__AVX2__)
394
+ for (; i + 8 <= n; i += 8) {
395
+ _mm256_storeu_ps(y + i,
396
+ _mm256_castsi256_ps(
397
+ _mm256_slli_epi32(
398
+ _mm256_cvtepu16_epi32(
399
+ _mm_loadu_si128(
400
+ (const __m128i *)(x + i))),
401
+ 16)));
402
+ }
403
+ #endif
404
+ for (; i < n; i++) {
405
+ y[i] = GGML_BF16_TO_FP32(x[i]);
406
+ }
407
+ }
408
+
409
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
+ int i = 0;
411
+ #if defined(__AVX512BF16__)
412
+ for (; i + 32 <= n; i += 32) {
413
+ _mm512_storeu_ps(
414
+ (__m512 *)(y + i),
415
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
+ _mm512_loadu_ps(x + i)));
417
+ }
418
+ #endif
419
+ for (; i < n; i++) {
420
+ y[i] = GGML_FP32_TO_BF16(x[i]);
421
+ }
422
+ }
423
+
424
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
425
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
426
  }
 
556
 
557
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
558
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
559
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
560
 
561
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
562
  [GGML_TYPE_I8] = {
 
899
  .type_size = sizeof(block_q8_K),
900
  .is_quantized = true,
901
  .from_float = quantize_row_q8_K,
902
+ },
903
+ [GGML_TYPE_BF16] = {
904
+ .type_name = "bf16",
905
+ .blck_size = 1,
906
+ .type_size = sizeof(ggml_bf16_t),
907
+ .is_quantized = false,
908
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
909
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
910
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
911
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
912
+ .vec_dot_type = GGML_TYPE_BF16,
913
+ .nrows = 1,
914
  }
915
  };
916
 
 
1546
 
1547
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1548
 
1549
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1550
+
1551
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1552
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1553
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
 
1566
  UNUSED(by);
1567
  UNUSED(bs);
1568
 
1569
+ #if defined(GGML_SIMD)
1570
  float sumf = 0.0f;
1571
  const int np = (n & ~(GGML_F32_STEP - 1));
1572
 
 
1602
  *s = sumf;
1603
  }
1604
 
1605
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1606
+ assert(nrc == 1);
1607
+ UNUSED(nrc);
1608
+ UNUSED(bx);
1609
+ UNUSED(by);
1610
+ UNUSED(bs);
1611
+ int i = 0;
1612
+ ggml_float sumf = 0;
1613
+
1614
+ #if defined(__AVX512BF16__)
1615
+ __m512 c1 = _mm512_setzero_ps();
1616
+ __m512 c2 = _mm512_setzero_ps();
1617
+ for (; i + 64 <= n; i += 64) {
1618
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1622
+ }
1623
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1625
+
1626
+ #elif defined(__AVX512F__)
1627
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1628
+ __m512 c1 = _mm512_setzero_ps();
1629
+ __m512 c2 = _mm512_setzero_ps();
1630
+ for (; i + 32 <= n; i += 32) {
1631
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1632
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1633
+ }
1634
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1635
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1636
+
1637
+ #undef LOAD
1638
+ #elif defined(__AVX2__)
1639
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1640
+ __m256 c1 = _mm256_setzero_ps();
1641
+ __m256 c2 = _mm256_setzero_ps();
1642
+ __m256 c3 = _mm256_setzero_ps();
1643
+ __m256 c4 = _mm256_setzero_ps();
1644
+ for (; i + 32 <= n; i += 32) {
1645
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1646
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1647
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1648
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1649
+ }
1650
+ __m128 g;
1651
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1652
+ _mm256_add_ps(c2, c4));
1653
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1654
+ _mm256_castps256_ps128(c1));
1655
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1656
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1657
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1658
+
1659
+ #undef LOAD
1660
+ #endif
1661
+
1662
+ for (; i < n; ++i) {
1663
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1664
+ GGML_BF16_TO_FP32(y[i]));
1665
+ }
1666
+ *s = sumf;
1667
+ }
1668
+
1669
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1670
  assert(nrc == 1);
1671
  UNUSED(nrc);
 
2100
  *s = sum;
2101
  }
2102
 
2103
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2104
+ float sum = 0.0f;
2105
+ for (int i = 0; i < n; ++i) {
2106
+ sum += GGML_BF16_TO_FP32(x[i]);
2107
+ }
2108
+ *s = sum;
2109
+ }
2110
+
2111
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
2112
  #ifndef GGML_USE_ACCELERATE
2113
  float max = -INFINITY;
 
2519
  // figure out which node we're on
2520
  uint current_cpu;
2521
  int getcpu_ret = 0;
2522
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2523
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2524
  #else
2525
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
 
2730
  switch (ftype) {
2731
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2732
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2733
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2734
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2735
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2736
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
 
2872
  {
2873
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2874
 
 
2875
  for (int i = 0; i < (1 << 16); ++i) {
2876
+ union {
2877
+ uint16_t u16;
2878
+ ggml_fp16_t fp16;
2879
+ } u = {i};
2880
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2881
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2882
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2883
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2884
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2885
  }
2886
 
2887
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
3345
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3346
  }
3347
  } break;
3348
+ case GGML_TYPE_BF16:
3349
+ {
3350
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
3351
+ for (int i = 0; i < n; i++) {
3352
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3353
+ }
3354
+ } break;
3355
  case GGML_TYPE_F32:
3356
  {
3357
  assert(tensor->nb[0] == sizeof(float));
 
3404
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3405
  }
3406
  } break;
3407
+ case GGML_TYPE_BF16:
3408
+ {
3409
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3410
+ for (int i = 0; i < n; i++) {
3411
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3412
+ }
3413
+ } break;
3414
  case GGML_TYPE_F32:
3415
  {
3416
  assert(tensor->nb[0] == sizeof(float));
 
3478
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3479
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3480
  }
3481
+ case GGML_TYPE_BF16:
3482
+ {
3483
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3484
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3485
+ }
3486
  case GGML_TYPE_F32:
3487
  {
3488
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
 
3525
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3526
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3527
  } break;
3528
+ case GGML_TYPE_BF16:
3529
+ {
3530
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3531
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3532
+ } break;
3533
  case GGML_TYPE_F32:
3534
  {
3535
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
 
3553
  return ((int32_t *) data)[0];
3554
  case GGML_TYPE_F16:
3555
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3556
+ case GGML_TYPE_BF16:
3557
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3558
  case GGML_TYPE_F32:
3559
  return ((float *) data)[0];
3560
  default:
 
3583
  {
3584
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3585
  } break;
3586
+ case GGML_TYPE_BF16:
3587
+ {
3588
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3589
+ } break;
3590
  case GGML_TYPE_F32:
3591
  {
3592
  ((float *)(data))[0] = value;
 
3625
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3626
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3627
  }
3628
+ case GGML_TYPE_BF16:
3629
+ {
3630
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3631
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3632
+ }
3633
  case GGML_TYPE_F32:
3634
  {
3635
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
 
3672
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3673
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3674
  } break;
3675
+ case GGML_TYPE_BF16:
3676
+ {
3677
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3678
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3679
+ } break;
3680
  case GGML_TYPE_F32:
3681
  {
3682
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
 
3700
  return ((int32_t *) data)[0];
3701
  case GGML_TYPE_F16:
3702
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3703
+ case GGML_TYPE_BF16:
3704
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3705
  case GGML_TYPE_F32:
3706
  return ((float *) data)[0];
3707
  default:
 
3730
  {
3731
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3732
  } break;
3733
+ case GGML_TYPE_BF16:
3734
+ {
3735
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3736
+ } break;
3737
  case GGML_TYPE_F32:
3738
  {
3739
  ((float *)(data))[0] = value;
 
3928
  // TODO: support less-strict constraint
3929
  // GGML_ASSERT(ggml_can_repeat(b, a));
3930
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3931
+
3932
+ // currently only supported for quantized input and f16
3933
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
3934
+ a->type == GGML_TYPE_F16 ||
3935
+ a->type == GGML_TYPE_BF16);
3936
 
3937
  bool is_node = false;
3938
 
 
7423
  ((char *) src0->data + ie0*nb00),
7424
  (ie1 - ie0) * ggml_type_size(src0->type));
7425
  }
 
7426
  }
7427
+
7428
  static void ggml_compute_forward_dup_f16(
7429
  const struct ggml_compute_params * params,
7430
  struct ggml_tensor * dst) {
 
7698
  }
7699
  }
7700
 
7701
+ static void ggml_compute_forward_dup_bf16(
7702
  const struct ggml_compute_params * params,
7703
  struct ggml_tensor * dst) {
7704
 
 
7746
  return;
7747
  }
7748
 
7749
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
7750
+
7751
  if (ggml_is_contiguous(dst)) {
7752
+ if (nb00 == sizeof(ggml_bf16_t)) {
7753
+ if (dst->type == GGML_TYPE_BF16) {
 
7754
  size_t id = 0;
7755
  const size_t rs = ne00 * nb00;
7756
  char * dst_ptr = (char *) dst->data;
 
7766
  id += rs * (ne01 - ir1);
7767
  }
7768
  }
7769
+ } else if (dst->type == GGML_TYPE_F16) {
7770
+ size_t id = 0;
7771
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7772
+
7773
+ for (int i03 = 0; i03 < ne03; i03++) {
7774
+ for (int i02 = 0; i02 < ne02; i02++) {
7775
+ id += ne00 * ir0;
7776
+ for (int i01 = ir0; i01 < ir1; i01++) {
7777
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7778
+ for (int i00 = 0; i00 < ne00; i00++) {
7779
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
7780
+ id++;
7781
+ }
7782
+ }
7783
+ id += ne00 * (ne01 - ir1);
7784
+ }
7785
+ }
7786
+ } else if (dst->type == GGML_TYPE_F32) {
7787
+ size_t id = 0;
7788
+ float * dst_ptr = (float *) dst->data;
7789
+
7790
+ for (int i03 = 0; i03 < ne03; i03++) {
7791
+ for (int i02 = 0; i02 < ne02; i02++) {
7792
+ id += ne00 * ir0;
7793
+ for (int i01 = ir0; i01 < ir1; i01++) {
7794
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7795
+ for (int i00 = 0; i00 < ne00; i00++) {
7796
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7797
+ id++;
7798
+ }
7799
+ }
7800
+ id += ne00 * (ne01 - ir1);
7801
+ }
7802
+ }
7803
  } else if (type_traits[dst->type].from_float) {
7804
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
7805
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7806
 
7807
  size_t id = 0;
7808
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
 
7812
  for (int i02 = 0; i02 < ne02; i02++) {
7813
  id += rs * ir0;
7814
  for (int i01 = ir0; i01 < ir1; i01++) {
7815
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7816
+
7817
+ for (int i00 = 0; i00 < ne00; i00++) {
7818
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7819
+ }
7820
+
7821
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
7822
  id += rs;
7823
  }
7824
  id += rs * (ne01 - ir1);
 
7839
  id += ne00 * ir0;
7840
  for (int i01 = ir0; i01 < ir1; i01++) {
7841
  for (int i00 = 0; i00 < ne00; i00++) {
7842
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7843
+
7844
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
7845
+ id++;
7846
+ }
7847
+ }
7848
+ id += ne00 * (ne01 - ir1);
7849
+ }
7850
+ }
7851
+ } else if (dst->type == GGML_TYPE_BF16) {
7852
+ size_t id = 0;
7853
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
7854
+
7855
+ for (int i03 = 0; i03 < ne03; i03++) {
7856
+ for (int i02 = 0; i02 < ne02; i02++) {
7857
+ id += ne00 * ir0;
7858
+ for (int i01 = ir0; i01 < ir1; i01++) {
7859
+ for (int i00 = 0; i00 < ne00; i00++) {
7860
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7861
 
7862
  dst_ptr[id] = *src0_ptr;
7863
  id++;
 
7875
  id += ne00 * ir0;
7876
  for (int i01 = ir0; i01 < ir1; i01++) {
7877
  for (int i00 = 0; i00 < ne00; i00++) {
7878
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7879
 
7880
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
7881
  id++;
7882
  }
7883
  }
 
7888
  GGML_ASSERT(false); // TODO: implement
7889
  }
7890
  }
 
7891
  return;
7892
  }
7893
 
7894
  // dst counters
 
7895
  int64_t i10 = 0;
7896
  int64_t i11 = 0;
7897
  int64_t i12 = 0;
7898
  int64_t i13 = 0;
7899
 
7900
+ if (dst->type == GGML_TYPE_BF16) {
7901
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7902
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7903
  i10 += ne00 * ir0;
 
7918
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7919
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7920
 
7921
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
7922
+
7923
+ if (++i10 == ne00) {
7924
+ i10 = 0;
7925
+ if (++i11 == ne01) {
7926
+ i11 = 0;
7927
+ if (++i12 == ne02) {
7928
+ i12 = 0;
7929
+ if (++i13 == ne03) {
7930
+ i13 = 0;
7931
+ }
7932
+ }
7933
+ }
7934
+ }
7935
+ }
7936
+ }
7937
+ i10 += ne00 * (ne01 - ir1);
7938
+ while (i10 >= ne0) {
7939
+ i10 -= ne0;
7940
+ if (++i11 == ne1) {
7941
+ i11 = 0;
7942
+ if (++i12 == ne2) {
7943
+ i12 = 0;
7944
+ if (++i13 == ne3) {
7945
+ i13 = 0;
7946
+ }
7947
+ }
7948
+ }
7949
+ }
7950
+ }
7951
+ }
7952
+ } else if (dst->type == GGML_TYPE_F16) {
7953
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7954
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7955
+ i10 += ne00 * ir0;
7956
+ while (i10 >= ne0) {
7957
+ i10 -= ne0;
7958
+ if (++i11 == ne1) {
7959
+ i11 = 0;
7960
+ if (++i12 == ne2) {
7961
+ i12 = 0;
7962
+ if (++i13 == ne3) {
7963
+ i13 = 0;
7964
+ }
7965
+ }
7966
+ }
7967
+ }
7968
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
7969
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7970
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7971
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7972
+
7973
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
7974
 
7975
  if (++i10 == ne0) {
7976
  i10 = 0;
 
8001
  }
8002
  }
8003
  }
8004
+ } else if (dst->type == GGML_TYPE_F32) {
8005
  for (int64_t i03 = 0; i03 < ne03; i03++) {
8006
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8007
  i10 += ne00 * ir0;
 
8022
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8023
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8024
 
8025
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
8026
 
8027
  if (++i10 == ne0) {
8028
  i10 = 0;
 
8058
  }
8059
  }
8060
 
8061
+ static void ggml_compute_forward_dup_f32(
 
8062
  const struct ggml_compute_params * params,
8063
  struct ggml_tensor * dst) {
8064
 
8065
  const struct ggml_tensor * src0 = dst->src[0];
8066
 
8067
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
8068
 
8069
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8070
  return;
8071
  }
8072
 
8073
+ GGML_TENSOR_UNARY_OP_LOCALS
 
 
 
 
 
8074
 
 
8075
  const int ith = params->ith; // thread index
8076
  const int nth = params->nth; // number of threads
8077
 
8078
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
8079
+ ggml_compute_forward_dup_same_cont(params, dst);
8080
+ return;
8081
+ }
8082
 
8083
  // parallelize by rows
8084
  const int nr = ne01;
 
8090
 
8091
  if (src0->type == dst->type &&
8092
  ne00 == ne0 &&
8093
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
8094
  // copy by rows
8095
+ const size_t rs = ne00*nb00;
8096
  for (int64_t i03 = 0; i03 < ne03; i03++) {
8097
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8098
  for (int64_t i01 = ir0; i01 < ir1; i01++) {
 
8107
  }
8108
 
8109
  if (ggml_is_contiguous(dst)) {
8110
+ // TODO: simplify
8111
+ if (nb00 == sizeof(float)) {
8112
+ if (dst->type == GGML_TYPE_F32) {
8113
+ size_t id = 0;
8114
+ const size_t rs = ne00 * nb00;
8115
+ char * dst_ptr = (char *) dst->data;
 
8116
 
8117
+ for (int i03 = 0; i03 < ne03; i03++) {
8118
+ for (int i02 = 0; i02 < ne02; i02++) {
8119
+ id += rs * ir0;
8120
+ for (int i01 = ir0; i01 < ir1; i01++) {
8121
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8122
+ memcpy(dst_ptr + id, src0_ptr, rs);
8123
+ id += rs;
8124
  }
8125
+ id += rs * (ne01 - ir1);
8126
  }
 
8127
  }
8128
+ } else if (type_traits[dst->type].from_float) {
8129
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8130
+
8131
+ size_t id = 0;
8132
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8133
+ char * dst_ptr = (char *) dst->data;
8134
+
8135
+ for (int i03 = 0; i03 < ne03; i03++) {
8136
+ for (int i02 = 0; i02 < ne02; i02++) {
8137
+ id += rs * ir0;
8138
+ for (int i01 = ir0; i01 < ir1; i01++) {
8139
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8140
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
8141
+ id += rs;
8142
+ }
8143
+ id += rs * (ne01 - ir1);
8144
+ }
8145
+ }
8146
+ } else {
8147
+ GGML_ASSERT(false); // TODO: implement
8148
+ }
8149
+ } else {
8150
+ //printf("%s: this is not optimal - fix me\n", __func__);
8151
+
8152
+ if (dst->type == GGML_TYPE_F32) {
8153
+ size_t id = 0;
8154
+ float * dst_ptr = (float *) dst->data;
8155
+
8156
+ for (int i03 = 0; i03 < ne03; i03++) {
8157
+ for (int i02 = 0; i02 < ne02; i02++) {
8158
+ id += ne00 * ir0;
8159
+ for (int i01 = ir0; i01 < ir1; i01++) {
8160
+ for (int i00 = 0; i00 < ne00; i00++) {
8161
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8162
+
8163
+ dst_ptr[id] = *src0_ptr;
8164
+ id++;
8165
+ }
8166
+ }
8167
+ id += ne00 * (ne01 - ir1);
8168
+ }
8169
+ }
8170
+ } else if (dst->type == GGML_TYPE_F16) {
8171
+ size_t id = 0;
8172
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8173
+
8174
+ for (int i03 = 0; i03 < ne03; i03++) {
8175
+ for (int i02 = 0; i02 < ne02; i02++) {
8176
+ id += ne00 * ir0;
8177
+ for (int i01 = ir0; i01 < ir1; i01++) {
8178
+ for (int i00 = 0; i00 < ne00; i00++) {
8179
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8180
+
8181
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8182
+ id++;
8183
+ }
8184
+ }
8185
+ id += ne00 * (ne01 - ir1);
8186
+ }
8187
+ }
8188
+ } else if (dst->type == GGML_TYPE_BF16) {
8189
+ size_t id = 0;
8190
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8191
+
8192
+ for (int i03 = 0; i03 < ne03; i03++) {
8193
+ for (int i02 = 0; i02 < ne02; i02++) {
8194
+ id += ne00 * ir0;
8195
+ for (int i01 = ir0; i01 < ir1; i01++) {
8196
+ for (int i00 = 0; i00 < ne00; i00++) {
8197
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8198
+
8199
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8200
+ id++;
8201
+ }
8202
+ }
8203
+ id += ne00 * (ne01 - ir1);
8204
+ }
8205
+ }
8206
+ } else {
8207
+ GGML_ASSERT(false); // TODO: implement
8208
+ }
8209
+ }
8210
+
8211
+ return;
8212
+ }
8213
+
8214
+ // dst counters
8215
+
8216
+ int64_t i10 = 0;
8217
+ int64_t i11 = 0;
8218
+ int64_t i12 = 0;
8219
+ int64_t i13 = 0;
8220
+
8221
+ if (dst->type == GGML_TYPE_F32) {
8222
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8223
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8224
+ i10 += ne00 * ir0;
8225
+ while (i10 >= ne0) {
8226
+ i10 -= ne0;
8227
+ if (++i11 == ne1) {
8228
+ i11 = 0;
8229
+ if (++i12 == ne2) {
8230
+ i12 = 0;
8231
+ if (++i13 == ne3) {
8232
+ i13 = 0;
8233
+ }
8234
+ }
8235
+ }
8236
+ }
8237
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8238
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8239
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8240
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8241
+
8242
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8243
+
8244
+ if (++i10 == ne0) {
8245
+ i10 = 0;
8246
+ if (++i11 == ne1) {
8247
+ i11 = 0;
8248
+ if (++i12 == ne2) {
8249
+ i12 = 0;
8250
+ if (++i13 == ne3) {
8251
+ i13 = 0;
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ }
8258
+ i10 += ne00 * (ne01 - ir1);
8259
+ while (i10 >= ne0) {
8260
+ i10 -= ne0;
8261
+ if (++i11 == ne1) {
8262
+ i11 = 0;
8263
+ if (++i12 == ne2) {
8264
+ i12 = 0;
8265
+ if (++i13 == ne3) {
8266
+ i13 = 0;
8267
+ }
8268
+ }
8269
+ }
8270
+ }
8271
+ }
8272
+ }
8273
+ } else if (dst->type == GGML_TYPE_F16) {
8274
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8275
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8276
+ i10 += ne00 * ir0;
8277
+ while (i10 >= ne0) {
8278
+ i10 -= ne0;
8279
+ if (++i11 == ne1) {
8280
+ i11 = 0;
8281
+ if (++i12 == ne2) {
8282
+ i12 = 0;
8283
+ if (++i13 == ne3) {
8284
+ i13 = 0;
8285
+ }
8286
+ }
8287
+ }
8288
+ }
8289
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8290
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8291
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8292
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8293
+
8294
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8295
+
8296
+ if (++i10 == ne0) {
8297
+ i10 = 0;
8298
+ if (++i11 == ne1) {
8299
+ i11 = 0;
8300
+ if (++i12 == ne2) {
8301
+ i12 = 0;
8302
+ if (++i13 == ne3) {
8303
+ i13 = 0;
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ }
8310
+ i10 += ne00 * (ne01 - ir1);
8311
+ while (i10 >= ne0) {
8312
+ i10 -= ne0;
8313
+ if (++i11 == ne1) {
8314
+ i11 = 0;
8315
+ if (++i12 == ne2) {
8316
+ i12 = 0;
8317
+ if (++i13 == ne3) {
8318
+ i13 = 0;
8319
+ }
8320
+ }
8321
+ }
8322
+ }
8323
+ }
8324
+ }
8325
+ } else if (dst->type == GGML_TYPE_BF16) {
8326
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8327
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8328
+ i10 += ne00 * ir0;
8329
+ while (i10 >= ne0) {
8330
+ i10 -= ne0;
8331
+ if (++i11 == ne1) {
8332
+ i11 = 0;
8333
+ if (++i12 == ne2) {
8334
+ i12 = 0;
8335
+ if (++i13 == ne3) {
8336
+ i13 = 0;
8337
+ }
8338
+ }
8339
+ }
8340
+ }
8341
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8342
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8343
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8344
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8345
+
8346
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
8347
+
8348
+ if (++i10 == ne0) {
8349
+ i10 = 0;
8350
+ if (++i11 == ne1) {
8351
+ i11 = 0;
8352
+ if (++i12 == ne2) {
8353
+ i12 = 0;
8354
+ if (++i13 == ne3) {
8355
+ i13 = 0;
8356
+ }
8357
+ }
8358
+ }
8359
+ }
8360
+ }
8361
+ }
8362
+ i10 += ne00 * (ne01 - ir1);
8363
+ while (i10 >= ne0) {
8364
+ i10 -= ne0;
8365
+ if (++i11 == ne1) {
8366
+ i11 = 0;
8367
+ if (++i12 == ne2) {
8368
+ i12 = 0;
8369
+ if (++i13 == ne3) {
8370
+ i13 = 0;
8371
+ }
8372
+ }
8373
+ }
8374
+ }
8375
+ }
8376
+ }
8377
+ } else {
8378
+ GGML_ASSERT(false); // TODO: implement
8379
+ }
8380
+ }
8381
+
8382
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
8383
+ static void ggml_compute_forward_dup_bytes(
8384
+ const struct ggml_compute_params * params,
8385
+ struct ggml_tensor * dst) {
8386
+
8387
+ const struct ggml_tensor * src0 = dst->src[0];
8388
+
8389
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8390
+ GGML_ASSERT(src0->type == dst->type);
8391
+
8392
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8393
+ return;
8394
+ }
8395
+
8396
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
8397
+ ggml_compute_forward_dup_same_cont(params, dst);
8398
+ return;
8399
+ }
8400
+
8401
+ GGML_TENSOR_UNARY_OP_LOCALS;
8402
+
8403
+ const size_t type_size = ggml_type_size(src0->type);
8404
+ const int ith = params->ith; // thread index
8405
+ const int nth = params->nth; // number of threads
8406
+
8407
+
8408
+ // parallelize by rows
8409
+ const int nr = ne01;
8410
+ // number of rows per thread
8411
+ const int dr = (nr + nth - 1) / nth;
8412
+ // row range for this thread
8413
+ const int ir0 = dr * ith;
8414
+ const int ir1 = MIN(ir0 + dr, nr);
8415
+
8416
+ if (src0->type == dst->type &&
8417
+ ne00 == ne0 &&
8418
+ nb00 == type_size && nb0 == type_size) {
8419
+ // copy by rows
8420
+ const size_t rs = ne00 * type_size;
8421
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8422
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8423
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8424
+ memcpy(
8425
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8426
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8427
+ rs);
8428
+ }
8429
+ }
8430
+ }
8431
+ return;
8432
+ }
8433
+
8434
+ if (ggml_is_contiguous(dst)) {
8435
+ size_t id = 0;
8436
+ char * dst_ptr = (char *) dst->data;
8437
+ const size_t rs = ne00 * type_size;
8438
+
8439
+ if (nb00 == type_size) {
8440
+ // src0 is contiguous on first dimension, copy by rows
8441
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8442
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8443
+ id += rs * ir0;
8444
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8445
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8446
+ memcpy(dst_ptr + id, src0_ptr, rs);
8447
+ id += rs;
8448
+ }
8449
+ id += rs * (ne01 - ir1);
8450
+ }
8451
+ }
8452
+ } else {
8453
+ //printf("%s: this is not optimal - fix me\n", __func__);
8454
+
8455
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8456
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8457
+ id += rs * ir0;
8458
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8459
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8460
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
8461
+ memcpy(dst_ptr + id, src0_ptr, type_size);
8462
+
8463
+ id += type_size;
8464
+ }
8465
+ }
8466
+ id += rs * (ne01 - ir1);
8467
+ }
8468
+ }
8469
+ }
8470
 
8471
  return;
8472
  }
 
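All of the new bf16 CPU kernels in this patch use the same row-based work split seen in ggml_compute_forward_dup_bytes above: rows are ceil-divided across threads (dr = (nr + nth - 1)/nth) and each thread handles the half-open range [ir0, ir1). A minimal standalone sketch with illustrative values for nr and nth (not taken from the patch):

// Sketch of the per-thread row partitioning used by the kernels above:
// ceil-divide the rows across the threads and clamp the last range.
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10;                     // total rows (illustrative)
    const int nth = 3;                      // number of threads (illustrative)
    const int dr  = (nr + nth - 1)/nth;     // rows per thread, rounded up
    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr*ith;             // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr);  // one past the last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}

With nr = 10 and nth = 3 this prints [0, 4), [4, 8) and [8, 10); because dr is rounded up, the last thread may get fewer rows, and a thread whose ir0 is already past nr simply does no work.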
8547
  {
8548
  ggml_compute_forward_dup_f16(params, dst);
8549
  } break;
8550
+ case GGML_TYPE_BF16:
8551
+ {
8552
+ ggml_compute_forward_dup_bf16(params, dst);
8553
+ } break;
8554
  case GGML_TYPE_F32:
8555
  {
8556
  ggml_compute_forward_dup_f32(params, dst);
 
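The strided (non-contiguous) copy path in ggml_compute_forward_dup_bf16 above tracks its position in the destination with four counters i10..i13 that behave like a mixed-radix odometer over (ne0, ne1, ne2, ne3): i10 advances once per element and carries into i11, i12 and i13 on overflow. A self-contained sketch of that carry logic, with an illustrative shape:

// Standalone sketch of the destination-counter carry used by the strided dup kernels:
// advance i10 per element and propagate carries i10 -> i11 -> i12 -> i13.
#include <stdio.h>
#include <stdint.h>

int main(void) {
    const int64_t ne0 = 2, ne1 = 3, ne2 = 1, ne3 = 1;  // illustrative destination shape
    int64_t i10 = 0, i11 = 0, i12 = 0, i13 = 0;
    for (int64_t k = 0; k < ne0*ne1*ne2*ne3; k++) {
        printf("dst index: (%lld, %lld, %lld, %lld)\n",
               (long long) i10, (long long) i11, (long long) i12, (long long) i13);
        if (++i10 == ne0) {
            i10 = 0;
            if (++i11 == ne1) {
                i11 = 0;
                if (++i12 == ne2) {
                    i12 = 0;
                    if (++i13 == ne3) {
                        i13 = 0;
                    }
                }
            }
        }
    }
    return 0;
}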
8644
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8645
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8646
 
8647
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
8648
+ const int64_t i10 = i0 % ne10;
8649
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
8650
+
8651
+ dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8652
+ }
8653
+ }
8654
+ }
8655
+ }
8656
+
8657
+ static void ggml_compute_forward_add_f16_f32(
8658
+ const struct ggml_compute_params * params,
8659
+ struct ggml_tensor * dst) {
8660
+
8661
+ const struct ggml_tensor * src0 = dst->src[0];
8662
+ const struct ggml_tensor * src1 = dst->src[1];
8663
+
8664
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8665
+
8666
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8667
+ return;
8668
+ }
8669
+
8670
+ const int ith = params->ith;
8671
+ const int nth = params->nth;
8672
+
8673
+ const int nr = ggml_nrows(src0);
8674
+
8675
+ GGML_TENSOR_BINARY_OP_LOCALS
8676
+
8677
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
8678
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8679
+
8680
+ if (dst->type == GGML_TYPE_F32) {
8681
+ GGML_ASSERT( nb0 == sizeof(float));
8682
+ }
8683
+ else {
8684
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
8685
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
8686
+ }
8687
+
8688
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
8689
+
8690
+ // rows per thread
8691
+ const int dr = (nr + nth - 1)/nth;
8692
+
8693
+ // row range for this thread
8694
+ const int ir0 = dr*ith;
8695
+ const int ir1 = MIN(ir0 + dr, nr);
8696
+
8697
+ if (nb10 == sizeof(float)) {
8698
+ if (dst->type == GGML_TYPE_F16) {
8699
+ for (int ir = ir0; ir < ir1; ++ir) {
8700
+ // src0, src1 and dst are same shape => same indices
8701
+ const int i3 = ir/(ne2*ne1);
8702
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8703
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8704
+
8705
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8706
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8707
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8708
+
8709
+ for (int i = 0; i < ne0; i++) {
8710
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8711
+ }
8712
+ }
8713
+ } else {
8714
+ for (int ir = ir0; ir < ir1; ++ir) {
8715
+ // src0, src1 and dst are same shape => same indices
8716
+ const int i3 = ir/(ne2*ne1);
8717
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8718
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8719
+
8720
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8721
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8722
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8723
 
8724
+ for (int i = 0; i < ne0; i++) {
8725
+ dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8726
+ }
8727
  }
8728
  }
8729
  }
8730
+ else {
8731
+ // src1 is not contiguous
8732
+ GGML_ASSERT(false);
8733
+ }
8734
  }
8735
 
8736
+ static void ggml_compute_forward_add_bf16_f32(
8737
  const struct ggml_compute_params * params,
8738
  struct ggml_tensor * dst) {
8739
 
 
8753
 
8754
  GGML_TENSOR_BINARY_OP_LOCALS
8755
 
8756
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8757
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8758
 
8759
  if (dst->type == GGML_TYPE_F32) {
8760
  GGML_ASSERT( nb0 == sizeof(float));
8761
  }
8762
  else {
8763
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8764
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8765
  }
8766
 
8767
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8768
 
8769
  // rows per thread
8770
  const int dr = (nr + nth - 1)/nth;
 
8774
  const int ir1 = MIN(ir0 + dr, nr);
8775
 
8776
  if (nb10 == sizeof(float)) {
8777
+ if (dst->type == GGML_TYPE_BF16) {
8778
  for (int ir = ir0; ir < ir1; ++ir) {
8779
  // src0, src1 and dst are same shape => same indices
8780
  const int i3 = ir/(ne2*ne1);
8781
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8782
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8783
 
8784
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8785
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8786
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8787
 
8788
  for (int i = 0; i < ne0; i++) {
8789
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8790
  }
8791
  }
8792
  } else {
 
8797
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8798
 
8799
  float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8800
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8801
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8802
 
8803
  for (int i = 0; i < ne0; i++) {
8804
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8805
  }
8806
  }
8807
  }
 
8868
  }
8869
  }
8870
 
8871
+ static void ggml_compute_forward_add_bf16_bf16(
8872
+ const struct ggml_compute_params * params,
8873
+ struct ggml_tensor * dst) {
8874
+
8875
+ const struct ggml_tensor * src0 = dst->src[0];
8876
+ const struct ggml_tensor * src1 = dst->src[1];
8877
+
8878
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8879
+
8880
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8881
+ return;
8882
+ }
8883
+
8884
+ const int ith = params->ith;
8885
+ const int nth = params->nth;
8886
+
8887
+ const int nr = ggml_nrows(src0);
8888
+
8889
+ GGML_TENSOR_BINARY_OP_LOCALS
8890
+
8891
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8892
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
8893
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8894
+
8895
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8896
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8897
+
8898
+ // rows per thread
8899
+ const int dr = (nr + nth - 1)/nth;
8900
+
8901
+ // row range for this thread
8902
+ const int ir0 = dr*ith;
8903
+ const int ir1 = MIN(ir0 + dr, nr);
8904
+
8905
+ if (nb10 == sizeof(ggml_bf16_t)) {
8906
+ for (int ir = ir0; ir < ir1; ++ir) {
8907
+ // src0, src1 and dst are same shape => same indices
8908
+ const int i3 = ir/(ne2*ne1);
8909
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8910
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8911
+
8912
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8913
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8914
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8915
+
8916
+ for (int i = 0; i < ne0; i++) {
8917
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
8918
+ }
8919
+ }
8920
+ }
8921
+ else {
8922
+ // src1 is not contiguous
8923
+ GGML_ASSERT(false);
8924
+ }
8925
+ }
8926
+
8927
  static void ggml_compute_forward_add_q_f32(
8928
  const struct ggml_compute_params * params,
8929
  struct ggml_tensor * dst) {
 
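The bf16 add kernels above all use the same widen-add-narrow pattern: operands are expanded to float32, the addition happens in float32, and the result is narrowed back to bf16 only when the destination is bf16. A minimal sketch of that inner loop, written against the public conversion functions declared in ggml.h (the kernels themselves use the internal GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16 macros); add_row_bf16 is a hypothetical helper, not part of the patch:

// Widen-add-narrow for one row of bf16 values, using the public ggml API.
#include "ggml.h"

static void add_row_bf16(ggml_bf16_t * dst, const ggml_bf16_t * a, const ggml_bf16_t * b, int64_t n) {
    for (int64_t i = 0; i < n; i++) {
        // expand both operands to float32, add, narrow the sum back to bf16
        dst[i] = ggml_fp32_to_bf16(ggml_bf16_to_fp32(a[i]) + ggml_bf16_to_fp32(b[i]));
    }
}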
9033
  GGML_ASSERT(false);
9034
  }
9035
  } break;
9036
+ case GGML_TYPE_BF16:
9037
+ {
9038
+ if (src1->type == GGML_TYPE_BF16) {
9039
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9040
+ }
9041
+ else if (src1->type == GGML_TYPE_F32) {
9042
+ ggml_compute_forward_add_bf16_f32(params, dst);
9043
+ }
9044
+ else {
9045
+ GGML_ASSERT(false);
9046
+ }
9047
+ } break;
9048
  case GGML_TYPE_Q4_0:
9049
  case GGML_TYPE_Q4_1:
9050
  case GGML_TYPE_Q5_0:
 
9303
  }
9304
  }
9305
 
9306
+ static void ggml_compute_forward_add1_bf16_f32(
9307
+ const struct ggml_compute_params * params,
9308
+ struct ggml_tensor * dst) {
9309
+
9310
+ const struct ggml_tensor * src0 = dst->src[0];
9311
+ const struct ggml_tensor * src1 = dst->src[1];
9312
+
9313
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9314
+ GGML_ASSERT(ggml_is_scalar(src1));
9315
+
9316
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9317
+ return;
9318
+ }
9319
+
9320
+ // scalar to add
9321
+ const float v = *(float *) src1->data;
9322
+
9323
+ const int ith = params->ith;
9324
+ const int nth = params->nth;
9325
+
9326
+ const int nr = ggml_nrows(src0);
9327
+
9328
+ GGML_TENSOR_UNARY_OP_LOCALS
9329
+
9330
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9331
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9332
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9333
+
9334
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9335
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9336
+
9337
+ // rows per thread
9338
+ const int dr = (nr + nth - 1)/nth;
9339
+
9340
+ // row range for this thread
9341
+ const int ir0 = dr*ith;
9342
+ const int ir1 = MIN(ir0 + dr, nr);
9343
+
9344
+ for (int ir = ir0; ir < ir1; ++ir) {
9345
+ // src0 and dst are same shape => same indices
9346
+ const int i3 = ir/(ne2*ne1);
9347
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9348
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9349
+
9350
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9351
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9352
+ for (int i = 0; i < ne0; i++) {
9353
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9354
+ }
9355
+ }
9356
+ }
9357
+
9358
+ static void ggml_compute_forward_add1_bf16_bf16(
9359
+ const struct ggml_compute_params * params,
9360
+ struct ggml_tensor * dst) {
9361
+
9362
+ const struct ggml_tensor * src0 = dst->src[0];
9363
+ const struct ggml_tensor * src1 = dst->src[1];
9364
+
9365
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9366
+ GGML_ASSERT(ggml_is_scalar(src1));
9367
+
9368
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9369
+ return;
9370
+ }
9371
+
9372
+ // scalar to add
9373
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9374
+
9375
+ const int ith = params->ith;
9376
+ const int nth = params->nth;
9377
+
9378
+ const int nr = ggml_nrows(src0);
9379
+
9380
+ GGML_TENSOR_UNARY_OP_LOCALS
9381
+
9382
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9383
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9384
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9385
+
9386
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9387
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9388
+
9389
+ // rows per thread
9390
+ const int dr = (nr + nth - 1)/nth;
9391
+
9392
+ // row range for this thread
9393
+ const int ir0 = dr*ith;
9394
+ const int ir1 = MIN(ir0 + dr, nr);
9395
+
9396
+ for (int ir = ir0; ir < ir1; ++ir) {
9397
+ // src0 and dst are same shape => same indices
9398
+ const int i3 = ir/(ne2*ne1);
9399
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9400
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9401
+
9402
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9403
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9404
+ for (int i = 0; i < ne0; i++) {
9405
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9406
+ }
9407
+ }
9408
+ }
9409
+
9410
  static void ggml_compute_forward_add1(
9411
  const struct ggml_compute_params * params,
9412
  struct ggml_tensor * dst) {
 
9431
  GGML_ASSERT(false);
9432
  }
9433
  } break;
9434
+ case GGML_TYPE_BF16:
9435
+ {
9436
+ if (src1->type == GGML_TYPE_BF16) {
9437
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9438
+ }
9439
+ else if (src1->type == GGML_TYPE_F32) {
9440
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9441
+ }
9442
+ else {
9443
+ GGML_ASSERT(false);
9444
+ }
9445
+ } break;
9446
  case GGML_TYPE_Q4_0:
9447
  case GGML_TYPE_Q4_1:
9448
  case GGML_TYPE_Q5_0:
 
9571
  ggml_compute_forward_acc_f32(params, dst);
9572
  } break;
9573
  case GGML_TYPE_F16:
9574
+ case GGML_TYPE_BF16:
9575
  case GGML_TYPE_Q4_0:
9576
  case GGML_TYPE_Q4_1:
9577
  case GGML_TYPE_Q5_0:
 
10093
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
10094
  }
10095
 
10096
+ static void ggml_compute_forward_sum_bf16(
10097
+ const struct ggml_compute_params * params,
10098
+ struct ggml_tensor * dst) {
10099
+
10100
+ const struct ggml_tensor * src0 = dst->src[0];
10101
+
10102
+ assert(params->ith == 0);
10103
+ assert(ggml_is_scalar(dst));
10104
+
10105
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10106
+ return;
10107
+ }
10108
+
10109
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10110
+
10111
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10112
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10113
+
10114
+ float sum = 0;
10115
+ float row_sum = 0;
10116
+
10117
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10118
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10119
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10120
+ ggml_vec_sum_bf16_ggf(ne00,
10121
+ &row_sum,
10122
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10123
+ sum += row_sum;
10124
+ }
10125
+ }
10126
+ }
10127
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10128
+ }
10129
+
10130
  static void ggml_compute_forward_sum(
10131
  const struct ggml_compute_params * params,
10132
  struct ggml_tensor * dst) {
 
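As with the existing f16 path, ggml_compute_forward_sum_bf16 above accumulates in float32 (sum and row_sum are float) and only narrows the final scalar to bf16, so rounding error is not compounded at bf16 precision during the reduction. ggml_vec_sum_bf16_ggf is defined elsewhere in ggml.c; a hypothetical stand-in that shows the idea using the public API:

// Illustrative row reduction: widen each bf16 element to float32 and accumulate in float32.
// sum_bf16_row is for illustration only and is not part of the GGML API.
#include "ggml.h"

static float sum_bf16_row(const ggml_bf16_t * x, int64_t n) {
    float sum = 0.0f;
    for (int64_t i = 0; i < n; i++) {
        sum += ggml_bf16_to_fp32(x[i]);
    }
    return sum;  // the kernel narrows the final total with GGML_FP32_TO_BF16
}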
10142
  {
10143
  ggml_compute_forward_sum_f16(params, dst);
10144
  } break;
10145
+ case GGML_TYPE_BF16:
10146
+ {
10147
+ ggml_compute_forward_sum_bf16(params, dst);
10148
+ } break;
10149
  default:
10150
  {
10151
  GGML_ASSERT(false);
 
10420
 
10421
  switch (src0->type) {
10422
  case GGML_TYPE_F16:
10423
+ case GGML_TYPE_BF16:
10424
  case GGML_TYPE_I16:
10425
  {
10426
  ggml_compute_forward_repeat_f16(params, dst);
 
12784
  ggml_compute_forward_set_f32(params, dst);
12785
  } break;
12786
  case GGML_TYPE_F16:
12787
+ case GGML_TYPE_BF16:
12788
  case GGML_TYPE_Q4_0:
12789
  case GGML_TYPE_Q4_1:
12790
  case GGML_TYPE_Q5_0:
 
12959
  }
12960
  }
12961
 
12962
+ static void ggml_compute_forward_get_rows_bf16(
12963
+ const struct ggml_compute_params * params,
12964
+ struct ggml_tensor * dst) {
12965
+
12966
+ const struct ggml_tensor * src0 = dst->src[0];
12967
+ const struct ggml_tensor * src1 = dst->src[1];
12968
+
12969
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12970
+ return;
12971
+ }
12972
+
12973
+ GGML_TENSOR_BINARY_OP_LOCALS
12974
+
12975
+ const int64_t nc = ne00;
12976
+ const int64_t nr = ggml_nelements(src1);
12977
+
12978
+ assert(ne0 == nc);
12979
+ assert(ne02 == ne11);
12980
+ assert(nb00 == sizeof(ggml_bf16_t));
12981
+ assert(ggml_nrows(dst) == nr);
12982
+
12983
+ const int ith = params->ith;
12984
+ const int nth = params->nth;
12985
+
12986
+ // rows per thread
12987
+ const int dr = (nr + nth - 1)/nth;
12988
+
12989
+ // row range for this thread
12990
+ const int ir0 = dr*ith;
12991
+ const int ir1 = MIN(ir0 + dr, nr);
12992
+
12993
+ for (int64_t i = ir0; i < ir1; ++i) {
12994
+ const int64_t i12 = i/(ne11*ne10);
12995
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
12996
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
12997
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
12998
+
12999
+ ggml_bf16_to_fp32_row(
13000
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
13001
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
13002
+ }
13003
+ }
13004
+
13005
  static void ggml_compute_forward_get_rows_f32(
13006
  const struct ggml_compute_params * params,
13007
  struct ggml_tensor * dst) {
 
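ggml_compute_forward_get_rows_bf16 above expands every looked-up bf16 row into a float32 destination row via ggml_bf16_to_fp32_row. A minimal sketch of that per-row call in isolation; convert_one_row and its buffer are hypothetical:

// Widen one contiguous bf16 row of nc elements to float32, as get_rows does per lookup.
#include <stdlib.h>
#include "ggml.h"

static float * convert_one_row(const ggml_bf16_t * row, int64_t nc) {
    float * out = malloc(nc * sizeof(float));
    if (out) {
        ggml_bf16_to_fp32_row(row, out, nc);  // nc bf16 values -> nc float32 values
    }
    return out;  // caller frees
}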
13079
  {
13080
  ggml_compute_forward_get_rows_f16(params, dst);
13081
  } break;
13082
+ case GGML_TYPE_BF16:
13083
+ {
13084
+ ggml_compute_forward_get_rows_bf16(params, dst);
13085
+ } break;
13086
  case GGML_TYPE_F32:
13087
  case GGML_TYPE_I32:
13088
  {
 
13778
  {
13779
  ggml_compute_forward_alibi_f32(params, dst);
13780
  } break;
13781
+ case GGML_TYPE_BF16:
13782
  case GGML_TYPE_Q4_0:
13783
  case GGML_TYPE_Q4_1:
13784
  case GGML_TYPE_Q5_0:
 
13868
  ggml_compute_forward_clamp_f32(params, dst);
13869
  } break;
13870
  case GGML_TYPE_F16:
13871
+ case GGML_TYPE_BF16:
13872
  case GGML_TYPE_Q4_0:
13873
  case GGML_TYPE_Q4_1:
13874
  case GGML_TYPE_Q5_0:
 
16966
 
16967
  switch (src0->type) {
16968
  case GGML_TYPE_F16:
16969
+ case GGML_TYPE_BF16:
16970
  {
16971
  ggml_compute_forward_get_rel_pos_f16(params, dst);
16972
  } break;
 
19836
  case GGML_OP_CPY:
19837
  case GGML_OP_DUP:
19838
  {
19839
+ if (ggml_is_quantized(node->type) ||
19840
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19841
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19842
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
19843
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
19844
  }
19845
  } break;
 
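For a sense of scale: with a hypothetical row length node->ne[0] = 4096 and n_tasks = 8, the reservation above is sizeof(float) * 4096 * 8 = 131072 bytes (128 KiB), i.e. one float32 row per task to hold the intermediate representation of a quantized or F16 <-> BF16 copy.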
19918
  const int64_t ne10 = node->src[1]->ne[0]; // L
19919
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
19920
 
19921
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19922
+ node->src[0]->type == GGML_TYPE_BF16) &&
19923
  node->src[1]->type == GGML_TYPE_F32) {
19924
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
19925
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
 
19955
  } else if (node->src[1]->type == GGML_TYPE_F16) {
19956
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19957
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19958
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19959
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19960
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19961
  }
19962
  } break;
19963
  case GGML_OP_FLASH_ATTN_EXT:
 
19974
  } else if (node->src[1]->type == GGML_TYPE_F16) {
19975
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19976
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19977
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19978
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19979
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19980
  }
19981
  } break;
19982
  case GGML_OP_FLASH_ATTN_BACK:
 
19990
  } else if (node->src[1]->type == GGML_TYPE_F16) {
19991
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19992
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19993
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19994
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19995
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19996
  }
19997
  } break;
19998
 
 
20769
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
20770
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
20771
  }
20772
+ else if (node->type == GGML_TYPE_F32 ||
20773
+ node->type == GGML_TYPE_F16 ||
20774
+ node->type == GGML_TYPE_BF16) {
20775
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
20776
  }
20777
  else {
 
21829
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
21830
  result = n * elemsize;
21831
  } break;
21832
+ case GGML_TYPE_BF16:
21833
+ {
21834
+ size_t elemsize = sizeof(ggml_bf16_t);
21835
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21836
+ result = n * elemsize;
21837
+ } break;
21838
  case GGML_TYPE_F32:
21839
  {
21840
  size_t elemsize = sizeof(float);
ggml.h CHANGED
@@ -326,14 +326,20 @@ extern "C" {
326
  // get ggml_status name string
327
  GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
328
 
 
 
329
  typedef uint16_t ggml_fp16_t;
330
-
331
- // convert FP16 <-> FP32
332
- GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
333
- GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
334
-
335
- GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n);
336
- GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n);
 
 
 
 
337
 
338
  struct ggml_object;
339
  struct ggml_context;
@@ -370,6 +376,7 @@ extern "C" {
370
  GGML_TYPE_I64 = 27,
371
  GGML_TYPE_F64 = 28,
372
  GGML_TYPE_IQ1_M = 29,
 
373
  GGML_TYPE_COUNT,
374
  };
375
 
@@ -410,6 +417,7 @@ extern "C" {
410
  GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
411
  GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
412
  GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
 
413
  };
414
 
415
  // available tensor operations:
 
326
  // get ggml_status name string
327
  GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
328
 
329
+ // ieee 754-2008 half-precision float16
330
+ // todo: make this not an integral type
331
  typedef uint16_t ggml_fp16_t;
332
+ GGML_API float ggml_fp16_to_fp32(ggml_fp16_t);
333
+ GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
334
+ GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
335
+ GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
336
+
337
+ // google brain half-precision bfloat16
338
+ typedef struct { uint16_t bits; } ggml_bf16_t;
339
+ GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
340
+ GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16
341
+ GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
342
+ GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
343
 
344
  struct ggml_object;
345
  struct ggml_context;
 
376
  GGML_TYPE_I64 = 27,
377
  GGML_TYPE_F64 = 28,
378
  GGML_TYPE_IQ1_M = 29,
379
+ GGML_TYPE_BF16 = 30,
380
  GGML_TYPE_COUNT,
381
  };
382
 
 
417
  GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
418
  GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
419
  GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
420
+ GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
421
  };
422
 
423
  // available tensor operations:
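Finally, a usage sketch for the public bf16 conversion API declared in the ggml.h hunk above, assuming ggml.h is on the include path and the program links against ggml; all values are illustrative:

// Round-trip a scalar and a small row through the new bf16 conversions.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    // scalar: float32 -> bf16 -> float32
    ggml_bf16_t h = ggml_fp32_to_bf16(3.14159f);
    printf("round trip: %f\n", ggml_bf16_to_fp32(h));  // close to 3.14159, at bf16 precision

    // row: float32[4] -> bf16[4] -> float32[4]
    float       src[4] = {1.0f, 0.5f, -2.0f, 0.1f};
    ggml_bf16_t packed[4];
    float       back[4];
    ggml_fp32_to_bf16_row(src, packed, 4);
    ggml_bf16_to_fp32_row(packed, back, 4);
    for (int i = 0; i < 4; i++) {
        printf("%f -> %f\n", src[i], back[i]);
    }
    return 0;
}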