ggerganov committed on
Commit
dabc473
·
1 Parent(s): c0943fb

Extend C-style API with full inference methods

Browse files
Files changed (4) hide show
  1. main.cpp +22 -166
  2. stream.cpp +25 -167
  3. whisper.cpp +260 -6
  4. whisper.h +21 -9
main.cpp CHANGED
@@ -5,17 +5,11 @@
5
  #define DR_WAV_IMPLEMENTATION
6
  #include "dr_wav.h"
7
 
8
- #include <cassert>
9
  #include <cstdio>
10
  #include <string>
11
  #include <thread>
12
  #include <vector>
13
 
14
- int64_t get_time_us() {
15
- return std::chrono::duration_cast<std::chrono::microseconds>(
16
- std::chrono::high_resolution_clock::now().time_since_epoch()).count();
17
- }
18
-
19
  // 500 -> 00:05.000
20
  // 6000 -> 01:00.000
21
  std::string to_timestamp(int64_t t) {
@@ -30,11 +24,6 @@ std::string to_timestamp(int64_t t) {
30
  return std::string(buf);
31
  }
32
 
33
- struct whisper_result {
34
- whisper_token id;
35
- int64_t t;
36
- };
37
-
38
  // command-line parameters
39
  struct whisper_params {
40
  int32_t seed = -1; // RNG seed, not used currently
@@ -111,8 +100,6 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
111
  }
112
 
113
  int main(int argc, char ** argv) {
114
- const int64_t t_main_start_us = get_time_us();
115
-
116
  whisper_params params;
117
 
118
  if (whisper_params_parse(argc, argv, params) == false) {
@@ -142,7 +129,7 @@ int main(int argc, char ** argv) {
142
  return 3;
143
  }
144
 
145
- if (wav.sampleRate != SAMPLE_RATE) {
146
  fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str());
147
  return 4;
148
  }
@@ -172,12 +159,6 @@ int main(int argc, char ** argv) {
172
  }
173
  }
174
 
175
- // compute log mel spectrogram
176
- if (whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), params.n_threads) != 0) {
177
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", argv[0]);
178
- return 6;
179
- }
180
-
181
  // print some info about the processing
182
  {
183
  printf("\n");
@@ -189,168 +170,43 @@ int main(int argc, char ** argv) {
189
  }
190
  }
191
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
192
- __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
193
  params.language.c_str(),
194
  params.translate ? "translate" : "transcribe",
195
  params.no_timestamps ? 0 : 1);
196
  printf("\n");
197
  }
198
 
199
- // the accumulated text context so far
200
- std::vector<whisper_token> prompt_past = { };
201
-
202
- // these tokens determine the task that will be performed
203
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
204
- if (whisper_is_multilingual(ctx)) {
205
- prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language.c_str()));
206
- if (params.translate) {
207
- prompt_init.push_back(whisper_token_translate());
208
- } else {
209
- prompt_init.push_back(whisper_token_transcribe());
210
- }
211
- }
212
-
213
- // the generated text including timestamps
214
- //std::vector<whisper_result> result_all;
215
-
216
- // main loop
217
- int seek = 0;
218
- while (true) {
219
- if (seek >= whisper_n_len(ctx)) {
220
- break;
221
- }
222
-
223
- // encode audio features starting at offset seek
224
- if (whisper_encode(ctx, seek, params.n_threads) != 0) {
225
- fprintf(stderr, "%s: failed to encode\n", __func__);
226
- return 7;
227
- }
228
-
229
- std::vector<whisper_token> prompt;
230
-
231
- int n_past = 0;
232
-
233
- // if we have already generated some text, use it as a prompt to condition the next generation
234
- if (prompt_past.size() > 0) {
235
- int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
236
-
237
- prompt = { whisper_token_prev(ctx) };
238
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
239
-
240
- prompt_past.clear();
241
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
242
- }
243
-
244
- prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
245
-
246
- bool done = false;
247
- int seek_delta = 100*CHUNK_SIZE;
248
- whisper_token last_id = 0;
249
-
250
- // print the prompt
251
- //printf("\n\n");
252
- //for (int i = 0; i < prompt.size(); i++) {
253
- // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
254
- //}
255
- //printf("\n\n");
256
-
257
- // the accumulated transcription in the current interation
258
- int result_len = 0;
259
- std::vector<whisper_result> result_cur;
260
-
261
- for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
262
- if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
263
- fprintf(stderr, "%s: failed to decode\n", __func__);
264
- return 8;
265
- }
266
-
267
- n_past += prompt.size();
268
- prompt.clear();
269
-
270
- // very basic greedy sampling strategy:
271
- //
272
- // - always take the most probable token
273
- //
274
- // more sophisticated sampling strategies could be implemented here, but we keep it simple
275
- // feel free to experiment!
276
- //
277
- {
278
- const int n_vocab = whisper_n_vocab(ctx);
279
-
280
- whisper_token id = 0;
281
- whisper_token tid = whisper_token_beg(ctx);
282
-
283
- id = whisper_sample_best(ctx, result_len == 0);
284
- if (i > 0) {
285
- tid = whisper_sample_timestamp(ctx);
286
- }
287
-
288
- // update sliding window
289
- if (id > whisper_token_beg(ctx)) {
290
- seek_delta = 2*(id - whisper_token_beg(ctx));
291
- result_len = i + 1;
292
- }
293
- last_id = id;
294
-
295
- // add it to the context
296
- prompt.push_back(id);
297
- result_cur.push_back({ id, seek + 2*(tid - whisper_token_beg(ctx)) });
298
-
299
- //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
300
 
301
- // end of text token
302
- if (id == whisper_token_eot(ctx)) {
303
- break;
304
- }
305
- }
306
 
307
- if (done) {
308
- break;
309
- }
310
  }
311
 
312
- result_cur.resize(result_len);
313
- //result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
314
-
315
- for (const auto & r : result_cur) {
316
- prompt_past.push_back(r.id);
317
- }
318
 
319
- // print the text from this iteration
320
- if (result_cur.size() > 0) {
321
- auto t0 = result_cur.front().t;
322
 
323
- std::string text = "";
324
- for (int i = 0; i < result_cur.size(); i++) {
325
- if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
326
  } else {
327
- text += whisper_token_to_str(ctx, result_cur[i].id);
328
- }
329
- if (result_cur[i].id > whisper_token_beg(ctx)) {
330
- const auto t1 = result_cur[i].t;
331
- if (!text.empty()) {
332
- if (params.no_timestamps) {
333
- printf ("%s", text.c_str());
334
- fflush(stdout);
335
- } else {
336
- printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
337
- }
338
- }
339
- text = "";
340
- while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
341
- i++;
342
- }
343
- i--;
344
- t0 = result_cur[i].t;
345
- }
346
- }
347
 
348
- if (!text.empty()) {
349
- printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(seek + seek_delta).c_str(), text.c_str());
350
  }
351
  }
352
-
353
- seek += seek_delta;
354
  }
355
 
356
  whisper_print_timings(ctx);
 
5
  #define DR_WAV_IMPLEMENTATION
6
  #include "dr_wav.h"
7
 
 
8
  #include <cstdio>
9
  #include <string>
10
  #include <thread>
11
  #include <vector>
12
 
 
 
 
 
 
13
  // 500 -> 00:05.000
14
  // 6000 -> 01:00.000
15
  std::string to_timestamp(int64_t t) {
 
24
  return std::string(buf);
25
  }
26
 
 
 
 
 
 
27
  // command-line parameters
28
  struct whisper_params {
29
  int32_t seed = -1; // RNG seed, not used currently
 
100
  }
101
 
102
  int main(int argc, char ** argv) {
 
 
103
  whisper_params params;
104
 
105
  if (whisper_params_parse(argc, argv, params) == false) {
 
129
  return 3;
130
  }
131
 
132
+ if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
133
  fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str());
134
  return 4;
135
  }
 
159
  }
160
  }
161
 
 
 
 
 
 
 
162
  // print some info about the processing
163
  {
164
  printf("\n");
 
170
  }
171
  }
172
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
173
+ __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
174
  params.language.c_str(),
175
  params.translate ? "translate" : "transcribe",
176
  params.no_timestamps ? 0 : 1);
177
  printf("\n");
178
  }
179
 
180
+ // run the inference
181
+ {
182
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ wparams.print_special_tokens = params.print_special_tokens;
 
 
 
 
185
 
186
+ if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
187
+ fprintf(stderr, "%s: failed to process audio\n", argv[0]);
188
+ return 6;
189
  }
190
 
191
+ // print the result
192
+ {
193
+ printf("\n");
 
 
 
194
 
195
+ const int n_segments = whisper_full_n_segments(ctx);
196
+ for (int i = 0; i < n_segments; ++i) {
197
+ const char * text = whisper_full_get_segment_text(ctx, i);
198
 
199
+ if (params.no_timestamps) {
200
+ printf ("%s", text);
201
+ fflush(stdout);
202
  } else {
203
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
204
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
207
+ }
208
  }
209
  }
 
 
210
  }
211
 
212
  whisper_print_timings(ctx);
stream.cpp CHANGED
@@ -18,11 +18,6 @@
18
  #include <thread>
19
  #include <vector>
20
 
21
- int64_t get_time_us() {
22
- return std::chrono::duration_cast<std::chrono::microseconds>(
23
- std::chrono::high_resolution_clock::now().time_since_epoch()).count();
24
- }
25
-
26
  // 500 -> 00:05.000
27
  // 6000 -> 01:00.000
28
  std::string to_timestamp(int64_t t) {
@@ -37,11 +32,6 @@ std::string to_timestamp(int64_t t) {
37
  return std::string(buf);
38
  }
39
 
40
- struct whisper_result {
41
- whisper_token id;
42
- int64_t t;
43
- };
44
-
45
  // command-line parameters
46
  struct whisper_params {
47
  int32_t seed = -1; // RNG seed, not used currently
@@ -155,7 +145,7 @@ bool audio_sdl_init(const int capture_id) {
155
  SDL_zero(capture_spec_requested);
156
  SDL_zero(capture_spec_obtained);
157
 
158
- capture_spec_requested.freq = SAMPLE_RATE;
159
  capture_spec_requested.format = AUDIO_F32;
160
  capture_spec_requested.channels = 1;
161
  capture_spec_requested.samples = 1024;
@@ -186,8 +176,6 @@ bool audio_sdl_init(const int capture_id) {
186
  ///////////////////////////
187
 
188
  int main(int argc, char ** argv) {
189
- const int64_t t_main_start_us = get_time_us();
190
-
191
  whisper_params params;
192
 
193
  if (whisper_params_parse(argc, argv, params) == false) {
@@ -209,7 +197,7 @@ int main(int argc, char ** argv) {
209
 
210
  struct whisper_context * ctx = whisper_init(params.model.c_str());
211
 
212
- const int n_samples_30s = 30*SAMPLE_RATE;
213
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
214
  std::vector<float> pcmf32_old;
215
 
@@ -224,7 +212,7 @@ int main(int argc, char ** argv) {
224
  }
225
  }
226
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
227
- __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
228
  params.language.c_str(),
229
  params.translate ? "translate" : "transcribe",
230
  params.no_timestamps ? 0 : 1);
@@ -250,7 +238,7 @@ int main(int argc, char ** argv) {
250
  }
251
 
252
  // process 3 seconds of new audio
253
- while ((int) SDL_GetQueuedAudioSize(g_dev_id_in) < 3*SAMPLE_RATE*sizeof(float)) {
254
  SDL_Delay(1);
255
  }
256
  const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
@@ -271,167 +259,37 @@ int main(int argc, char ** argv) {
271
 
272
  pcmf32_old = pcmf32;
273
 
274
- // compute log mel spectrogram
275
- if (whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), params.n_threads) != 0) {
276
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", argv[0]);
277
- return 6;
278
- }
279
-
280
- // the accumulated text context so far
281
- std::vector<whisper_token> prompt_past = { };
282
-
283
- // these tokens determine the task that will be performed
284
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
285
- if (whisper_is_multilingual(ctx)) {
286
- prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language.c_str()));
287
- if (params.translate) {
288
- prompt_init.push_back(whisper_token_translate());
289
- } else {
290
- prompt_init.push_back(whisper_token_transcribe());
291
- }
292
- }
293
-
294
- // the generated text including timestamps
295
- //std::vector<whisper_result> result_all;
296
-
297
- // main loop
298
- int seek = 0;
299
- while (true) {
300
- if (seek >= whisper_n_len(ctx)) {
301
- break;
302
- }
303
-
304
- // encode audio features starting at offset seek
305
- if (whisper_encode(ctx, seek, params.n_threads) != 0) {
306
- fprintf(stderr, "%s: failed to encode\n", __func__);
307
- return 7;
308
- }
309
-
310
- std::vector<whisper_token> prompt;
311
-
312
- int n_past = 0;
313
-
314
- // if we have already generated some text, use it as a prompt to condition the next generation
315
- if (prompt_past.size() > 0) {
316
- int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
317
-
318
- prompt = { whisper_token_prev(ctx) };
319
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
320
-
321
- prompt_past.clear();
322
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
323
- }
324
-
325
- prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
326
-
327
- bool done = false;
328
- int seek_delta = 100*CHUNK_SIZE;
329
- whisper_token last_id = 0;
330
-
331
- // print the prompt
332
- //printf("\n\n");
333
- //for (int i = 0; i < prompt.size(); i++) {
334
- // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
335
- //}
336
- //printf("\n\n");
337
-
338
- // the accumulated transcription in the current interation
339
- int result_len = 0;
340
- std::vector<whisper_result> result_cur;
341
-
342
- for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
343
- if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
344
- fprintf(stderr, "%s: failed to decode\n", __func__);
345
- return 8;
346
- }
347
-
348
- n_past += prompt.size();
349
- prompt.clear();
350
-
351
- // very basic greedy sampling strategy:
352
- //
353
- // - always take the most probable token
354
- //
355
- // more sophisticated sampling strategies could be implemented here, but we keep it simple
356
- // feel free to experiment!
357
- //
358
- {
359
- const int n_vocab = whisper_n_vocab(ctx);
360
-
361
- whisper_token id = 0;
362
- whisper_token tid = whisper_token_beg(ctx);
363
-
364
- id = whisper_sample_best(ctx, result_len == 0);
365
- if (i > 0) {
366
- tid = whisper_sample_timestamp(ctx);
367
- }
368
-
369
- // update sliding window
370
- if (id > whisper_token_beg(ctx)) {
371
- seek_delta = 2*(id - whisper_token_beg(ctx));
372
- result_len = i + 1;
373
- }
374
- last_id = id;
375
-
376
- // add it to the context
377
- prompt.push_back(id);
378
- result_cur.push_back({ id, seek + 2*(tid - whisper_token_beg(ctx)) });
379
-
380
- //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
381
 
382
- // end of text token
383
- if (id == whisper_token_eot(ctx)) {
384
- break;
385
- }
386
- }
387
 
388
- if (done) {
389
- break;
390
- }
391
  }
392
 
393
- result_cur.resize(result_len);
394
- //result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
395
-
396
- for (const auto & r : result_cur) {
397
- prompt_past.push_back(r.id);
398
- }
399
 
400
- // print the text from this iteration
401
- if (result_cur.size() > 0) {
402
- auto t0 = result_cur.front().t;
403
 
404
- std::string text = "";
405
- for (int i = 0; i < result_cur.size(); i++) {
406
- if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
407
  } else {
408
- text += whisper_token_to_str(ctx, result_cur[i].id);
409
- }
410
- if (result_cur[i].id > whisper_token_beg(ctx)) {
411
- const auto t1 = result_cur[i].t;
412
- if (!text.empty()) {
413
- if (params.no_timestamps) {
414
- printf ("%s", text.c_str());
415
- fflush(stdout);
416
- } else {
417
- printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
418
- }
419
- }
420
- text = "";
421
- while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
422
- i++;
423
- }
424
- i--;
425
- t0 = result_cur[i].t;
426
- }
427
- }
428
 
429
- if (!text.empty()) {
430
- printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(seek + seek_delta).c_str(), text.c_str());
431
  }
432
  }
433
-
434
- seek += seek_delta;
435
  }
436
  }
437
 
 
18
  #include <thread>
19
  #include <vector>
20
 
 
 
 
 
 
21
  // 500 -> 00:05.000
22
  // 6000 -> 01:00.000
23
  std::string to_timestamp(int64_t t) {
 
32
  return std::string(buf);
33
  }
34
 
 
 
 
 
 
35
  // command-line parameters
36
  struct whisper_params {
37
  int32_t seed = -1; // RNG seed, not used currently
 
145
  SDL_zero(capture_spec_requested);
146
  SDL_zero(capture_spec_obtained);
147
 
148
+ capture_spec_requested.freq = WHISPER_SAMPLE_RATE;
149
  capture_spec_requested.format = AUDIO_F32;
150
  capture_spec_requested.channels = 1;
151
  capture_spec_requested.samples = 1024;
 
176
  ///////////////////////////
177
 
178
  int main(int argc, char ** argv) {
 
 
179
  whisper_params params;
180
 
181
  if (whisper_params_parse(argc, argv, params) == false) {
 
197
 
198
  struct whisper_context * ctx = whisper_init(params.model.c_str());
199
 
200
+ const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
201
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
202
  std::vector<float> pcmf32_old;
203
 
 
212
  }
213
  }
214
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
215
+ __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
216
  params.language.c_str(),
217
  params.translate ? "translate" : "transcribe",
218
  params.no_timestamps ? 0 : 1);
 
238
  }
239
 
240
  // process 3 seconds of new audio
241
+ while ((int) SDL_GetQueuedAudioSize(g_dev_id_in) < 3*WHISPER_SAMPLE_RATE*sizeof(float)) {
242
  SDL_Delay(1);
243
  }
244
  const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
 
259
 
260
  pcmf32_old = pcmf32;
261
 
262
+ // run the inference
263
+ {
264
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
+ wparams.print_progress = false;
267
+ wparams.print_special_tokens = params.print_special_tokens;
 
 
 
268
 
269
+ if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
270
+ fprintf(stderr, "%s: failed to process audio\n", argv[0]);
271
+ return 6;
272
  }
273
 
274
+ // print the result
275
+ {
276
+ printf("\n");
 
 
 
277
 
278
+ const int n_segments = whisper_full_n_segments(ctx);
279
+ for (int i = 0; i < n_segments; ++i) {
280
+ const char * text = whisper_full_get_segment_text(ctx, i);
281
 
282
+ if (params.no_timestamps) {
283
+ printf ("%s", text);
284
+ fflush(stdout);
285
  } else {
286
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
287
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
290
+ }
291
  }
292
  }
 
 
293
  }
294
  }
295
 
whisper.cpp CHANGED
@@ -210,8 +210,15 @@ struct whisper_vocab {
210
  };
211
 
212
  struct whisper_result {
213
- whisper_vocab::id id;
214
  int64_t t;
 
 
 
 
 
 
 
 
215
  };
216
 
217
  // medium
@@ -395,6 +402,9 @@ struct whisper_context {
395
 
396
  std::vector<float> probs;
397
  std::vector<float> logits;
 
 
 
398
  };
399
 
400
  // load the model from a ggml file
@@ -1946,8 +1956,8 @@ bool log_mel_spectrogram(
1946
 
1947
  const int n_fft = 1 + fft_size/2;
1948
 
1949
- printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
1950
- printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
1951
 
1952
  std::vector<std::thread> workers(n_threads);
1953
  for (int iw = 0; iw < n_threads; ++iw) {
@@ -2066,7 +2076,7 @@ void whisper_free(struct whisper_context * ctx) {
2066
  int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2067
  const int64_t t_start_us = ggml_time_us();
2068
 
2069
- if (!log_mel_spectrogram(samples, n_samples, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
2070
  fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2071
  return -1;
2072
  }
@@ -2081,8 +2091,8 @@ int whisper_set_mel(
2081
  const float * data,
2082
  int n_len,
2083
  int n_mel) {
2084
- if (n_mel != N_MEL) {
2085
- fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, N_MEL);
2086
  return -1;
2087
  }
2088
 
@@ -2219,3 +2229,247 @@ void whisper_print_timings(struct whisper_context * ctx) {
2219
  printf("%s: decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
2220
  printf("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
2221
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  };
211
 
212
  struct whisper_result {
 
213
  int64_t t;
214
+ whisper_token id;
215
+ };
216
+
217
+ struct whisper_segment {
218
+ int64_t t0;
219
+ int64_t t1;
220
+
221
+ std::string text;
222
  };
223
 
224
  // medium
 
402
 
403
  std::vector<float> probs;
404
  std::vector<float> logits;
405
+
406
+ std::vector<whisper_result> result_cur;
407
+ std::vector<whisper_segment> result_all;
408
  };
409
 
410
  // load the model from a ggml file
 
1956
 
1957
  const int n_fft = 1 + fft_size/2;
1958
 
1959
+ //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
1960
+ //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
1961
 
1962
  std::vector<std::thread> workers(n_threads);
1963
  for (int iw = 0; iw < n_threads; ++iw) {
 
2076
  int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2077
  const int64_t t_start_us = ggml_time_us();
2078
 
2079
+ if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
2080
  fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2081
  return -1;
2082
  }
 
2091
  const float * data,
2092
  int n_len,
2093
  int n_mel) {
2094
+ if (n_mel != WHISPER_N_MEL) {
2095
+ fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2096
  return -1;
2097
  }
2098
 
 
2229
  printf("%s: decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
2230
  printf("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
2231
  }
2232
+
2233
+ ////////////////////////////////////////////////////////////////////////////
2234
+
2235
+ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy) {
2236
+ struct whisper_full_params result;
2237
+
2238
+ switch (strategy) {
2239
+ case WHISPER_DECODE_GREEDY:
2240
+ {
2241
+ result = (struct whisper_full_params) {
2242
+ .strategy = WHISPER_DECODE_GREEDY,
2243
+ .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
2244
+
2245
+ .translate = false,
2246
+ .print_special_tokens = false,
2247
+ .print_progress = true,
2248
+
2249
+ .language = "en",
2250
+
2251
+ .greedy = {
2252
+ .n_past = 0,
2253
+ },
2254
+ };
2255
+ } break;
2256
+ case WHISPER_DECODE_BEAM_SEARCH:
2257
+ {
2258
+ result = (struct whisper_full_params) {
2259
+ .strategy = WHISPER_DECODE_GREEDY,
2260
+ .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
2261
+
2262
+ .translate = false,
2263
+ .print_special_tokens = false,
2264
+ .print_progress = true,
2265
+
2266
+ .language = "en",
2267
+
2268
+ .beam_search = {
2269
+ .n_past = 0,
2270
+ .beam_width = 10,
2271
+ .n_best = 5,
2272
+ },
2273
+ };
2274
+ } break;
2275
+ }
2276
+
2277
+ return result;
2278
+ }
2279
+ int whisper_full(
2280
+ struct whisper_context * ctx,
2281
+ struct whisper_full_params params,
2282
+ const float * samples,
2283
+ int n_samples) {
2284
+ // compute log mel spectrogram
2285
+ if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
2286
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
2287
+ return -1;
2288
+ }
2289
+
2290
+ // the accumulated text context so far
2291
+ std::vector<whisper_token> prompt_past = { };
2292
+
2293
+ // these tokens determine the task that will be performed
2294
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
2295
+ if (whisper_is_multilingual(ctx)) {
2296
+ prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
2297
+ if (params.translate) {
2298
+ prompt_init.push_back(whisper_token_translate());
2299
+ } else {
2300
+ prompt_init.push_back(whisper_token_transcribe());
2301
+ }
2302
+ }
2303
+
2304
+ auto & result_all = ctx->result_all;
2305
+ auto & result_cur = ctx->result_cur;
2306
+
2307
+ result_all.clear();
2308
+
2309
+ int progress_prev = 0;
2310
+ int progress_step = 5;
2311
+
2312
+ // main loop
2313
+ int seek = 0;
2314
+ while (true) {
2315
+ int progress_cur = (100*seek)/whisper_n_len(ctx);
2316
+ while (progress_cur >= progress_prev + progress_step) {
2317
+ progress_prev += progress_step;
2318
+ if (params.print_progress) {
2319
+ printf("%s: progress = %3d%%\n", __func__, progress_prev);
2320
+ }
2321
+ }
2322
+
2323
+ if (seek >= whisper_n_len(ctx)) {
2324
+ break;
2325
+ }
2326
+
2327
+ // encode audio features starting at offset seek
2328
+ if (whisper_encode(ctx, seek, params.n_threads) != 0) {
2329
+ fprintf(stderr, "%s: failed to encode\n", __func__);
2330
+ return 7;
2331
+ }
2332
+
2333
+ std::vector<whisper_token> prompt;
2334
+
2335
+ int n_past = 0;
2336
+
2337
+ // if we have already generated some text, use it as a prompt to condition the next generation
2338
+ if (prompt_past.size() > 0) {
2339
+ int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
2340
+
2341
+ prompt = { whisper_token_prev(ctx) };
2342
+ prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
2343
+
2344
+ prompt_past.clear();
2345
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
2346
+ }
2347
+
2348
+ prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
2349
+
2350
+ bool done = false;
2351
+ int seek_delta = 100*WHISPER_CHUNK_SIZE;
2352
+ whisper_token last_id = 0;
2353
+
2354
+ // print the prompt
2355
+ //printf("\n\n");
2356
+ //for (int i = 0; i < prompt.size(); i++) {
2357
+ // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
2358
+ //}
2359
+ //printf("\n\n");
2360
+
2361
+ // the accumulated transcription in the current iteration
2362
+ int result_len = 0;
2363
+ result_cur.clear();
2364
+
2365
+ for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
2366
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
2367
+ fprintf(stderr, "%s: failed to decode\n", __func__);
2368
+ return 8;
2369
+ }
2370
+
2371
+ n_past += prompt.size();
2372
+ prompt.clear();
2373
+
2374
+ // very basic greedy sampling strategy:
2375
+ //
2376
+ // - always take the most probable token
2377
+ //
2378
+ // more sophisticated sampling strategies could be implemented here, but we keep it simple
2379
+ // feel free to experiment!
2380
+ //
2381
+ {
2382
+ const int n_vocab = whisper_n_vocab(ctx);
2383
+
2384
+ whisper_token id = 0;
2385
+ whisper_token tid = whisper_token_beg(ctx);
2386
+
2387
+ id = whisper_sample_best(ctx, result_len == 0);
2388
+ if (i > 0) {
2389
+ tid = whisper_sample_timestamp(ctx);
2390
+ }
2391
+
2392
+ // update sliding window
2393
+ if (id > whisper_token_beg(ctx)) {
2394
+ seek_delta = 2*(id - whisper_token_beg(ctx));
2395
+ result_len = i + 1;
2396
+ }
2397
+ last_id = id;
2398
+
2399
+ // add it to the context
2400
+ prompt.push_back(id);
2401
+ result_cur.push_back({ seek + 2*(tid - whisper_token_beg(ctx)), id });
2402
+
2403
+ //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
2404
+
2405
+ // end of text token
2406
+ if (id == whisper_token_eot(ctx)) {
2407
+ if (result_len == 0) {
2408
+ result_len = i + 1;
2409
+ }
2410
+ break;
2411
+ }
2412
+ }
2413
+
2414
+ if (done) {
2415
+ break;
2416
+ }
2417
+ }
2418
+
2419
+ result_cur.resize(result_len);
2420
+
2421
+ for (const auto & r : result_cur) {
2422
+ prompt_past.push_back(r.id);
2423
+ }
2424
+
2425
+ // store the text from this iteration
2426
+ if (result_cur.size() > 0) {
2427
+ auto t0 = result_cur.front().t;
2428
+
2429
+ std::string text = "";
2430
+
2431
+ for (int i = 0; i < result_cur.size(); i++) {
2432
+ if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
2433
+ } else {
2434
+ text += whisper_token_to_str(ctx, result_cur[i].id);
2435
+ }
2436
+ if (result_cur[i].id > whisper_token_beg(ctx)) {
2437
+ const auto t1 = result_cur[i].t;
2438
+ if (!text.empty()) {
2439
+ result_all.push_back({ t0, t1, text });
2440
+ }
2441
+ text = "";
2442
+ while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
2443
+ i++;
2444
+ }
2445
+ i--;
2446
+ t0 = result_cur[i].t;
2447
+ }
2448
+ }
2449
+
2450
+ if (!text.empty()) {
2451
+ result_all.push_back({ t0, seek + seek_delta, text });
2452
+ }
2453
+ }
2454
+
2455
+ seek += seek_delta;
2456
+ }
2457
+
2458
+ return 0;
2459
+ }
2460
+
2461
+ int whisper_full_n_segments(struct whisper_context * ctx) {
2462
+ return ctx->result_all.size();
2463
+ }
2464
+
2465
+ int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
2466
+ return ctx->result_all[i_segment].t0;
2467
+ }
2468
+
2469
+ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
2470
+ return ctx->result_all[i_segment].t1;
2471
+ }
2472
+
2473
+ const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
2474
+ return ctx->result_all[i_segment].text.c_str();
2475
+ }
whisper.h CHANGED
@@ -1,6 +1,8 @@
1
  #ifndef WHISPER_H
2
  #define WHISPER_H
3
 
 
 
4
  #ifdef WHISPER_SHARED
5
  # ifdef _WIN32
6
  # ifdef WHISPER_BUILD
@@ -15,6 +17,12 @@
15
  # define WHISPER_API
16
  #endif
17
 
 
 
 
 
 
 
18
  #ifdef __cplusplus
19
  extern "C" {
20
  #endif
@@ -23,12 +31,6 @@ extern "C" {
23
  // C interface
24
  //
25
 
26
- #define SAMPLE_RATE 16000
27
- #define N_FFT 400
28
- #define N_MEL 80
29
- #define HOP_LENGTH 160
30
- #define CHUNK_SIZE 30
31
-
32
  // TODO: documentation will come soon
33
 
34
  struct whisper_context;
@@ -101,7 +103,9 @@ extern "C" {
101
 
102
  int n_threads;
103
 
104
- bool transcribe;
 
 
105
 
106
  const char * language;
107
 
@@ -118,14 +122,22 @@ extern "C" {
118
  };
119
  };
120
 
 
 
121
  // full whisper run - encode + decode
122
- // TODO: implement
123
  WHISPER_API int whisper_full(
124
  struct whisper_context * ctx,
125
- struct whisper_full_params * params,
126
  const float * samples,
127
  int n_samples);
128
 
 
 
 
 
 
 
 
129
  #ifdef __cplusplus
130
  }
131
  #endif
 
1
  #ifndef WHISPER_H
2
  #define WHISPER_H
3
 
4
+ #include <stdint.h>
5
+
6
  #ifdef WHISPER_SHARED
7
  # ifdef _WIN32
8
  # ifdef WHISPER_BUILD
 
17
  # define WHISPER_API
18
  #endif
19
 
20
+ #define WHISPER_SAMPLE_RATE 16000
21
+ #define WHISPER_N_FFT 400
22
+ #define WHISPER_N_MEL 80
23
+ #define WHISPER_HOP_LENGTH 160
24
+ #define WHISPER_CHUNK_SIZE 30
25
+
26
  #ifdef __cplusplus
27
  extern "C" {
28
  #endif
 
31
  // C interface
32
  //
33
 
 
 
 
 
 
 
34
  // TODO: documentation will come soon
35
 
36
  struct whisper_context;
 
103
 
104
  int n_threads;
105
 
106
+ bool translate;
107
+ bool print_special_tokens;
108
+ bool print_progress;
109
 
110
  const char * language;
111
 
 
122
  };
123
  };
124
 
125
+ WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy);
126
+
127
  // full whisper run - encode + decode
 
128
  WHISPER_API int whisper_full(
129
  struct whisper_context * ctx,
130
+ struct whisper_full_params params,
131
  const float * samples,
132
  int n_samples);
133
 
134
+ WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
135
+
136
+ WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
137
+ WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
138
+
139
+ WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
140
+
141
  #ifdef __cplusplus
142
  }
143
  #endif