ggerganov committed on
Commit 8dfec0c · unverified · 1 Parent(s): 20ca90d

whisper : avoid some memory allocations

Files changed (1): whisper.cpp (+21 -6)
whisper.cpp CHANGED
@@ -204,6 +204,10 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
+    // used to avoid memory allocations during sampling
+    // TODO: move to whisper_context in the future
+    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@@ -551,6 +555,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         std::string word;
         std::vector<char> tmp;
+
+        tmp.reserve(128);
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             read_safe(fin, len);
@@ -603,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 vocab.id_to_token[i] = word;
             }
         }
+
+        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+
+        vocab.probs_id.reserve(n_vocab);
     }
 
     {
@@ -1021,7 +1033,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
             std::string name;
             std::vector<char> tmp(length); // create a buffer
-            fin.read( &tmp[0], tmp.size() ); // read to buffer
+            fin.read(&tmp[0], tmp.size()); // read to buffer
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
@@ -1849,7 +1861,7 @@ static bool whisper_decode(
 
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        const whisper_vocab & vocab,
+              whisper_vocab & vocab,
         const float * probs,
         bool force_timestamp,
         bool is_initial) {
@@ -1857,11 +1869,11 @@ static whisper_token_data whisper_sample_best(
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
 
-    int n_logits = vocab.id_to_token.size();
+    const int n_logits = vocab.n_vocab;
 
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
+    auto & probs_id = vocab.probs_id;
 
+    probs_id.clear();
     for (int i = 0; i < n_logits; i++) {
         probs_id.emplace_back(probs[i], i);
     }
@@ -2001,6 +2013,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     std::vector<float> even;
     std::vector<float> odd;
 
+    even.reserve(N/2);
+    odd.reserve(N/2);
+
     for (int i = 0; i < N; i++) {
         if (i % 2 == 0) {
             even.push_back(in[i]);
@@ -2434,7 +2449,7 @@ int whisper_lang_auto_detect(
     std::vector<std::pair<float, int>> probs_id;
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
+        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
     }
 
     // sort descending
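
The common thread in the hunks above is avoiding per-call heap allocations: tmp, wctx.logits, wctx.probs, and the FFT scratch vectors are reserved up front, and whisper_sample_best now reuses the persistent vocab.probs_id instead of constructing a fresh vector on every call. Below is a minimal standalone sketch of that pattern; the SamplerScratch struct, pick_best function, and toy sizes are illustrative placeholders, not part of whisper.cpp.

// Minimal sketch of the allocation-avoidance pattern applied in this commit.
// Names (SamplerScratch, pick_best) and sizes are illustrative only.
#include <cstdio>
#include <utility>
#include <vector>

struct SamplerScratch {
    // Persistent scratch buffer, analogous to whisper_vocab::probs_id:
    // reserved once, then cleared and refilled on every call.
    std::vector<std::pair<double, int>> probs_id;

    void init(size_t n_vocab) {
        probs_id.reserve(n_vocab); // one allocation up front
    }
};

// Selects the index of the largest probability without allocating,
// as long as n does not exceed the reserved capacity.
static int pick_best(SamplerScratch & scratch, const float * probs, int n) {
    auto & probs_id = scratch.probs_id;
    probs_id.clear(); // resets the size but keeps the capacity
    for (int i = 0; i < n; i++) {
        probs_id.emplace_back(probs[i], i);
    }

    int best = 0;
    for (int i = 1; i < n; i++) {
        if (probs_id[i].first > probs_id[best].first) {
            best = i;
        }
    }
    return probs_id[best].second;
}

int main() {
    const int n_vocab = 4;
    SamplerScratch scratch;
    scratch.init(n_vocab);

    const float probs[n_vocab] = {0.10f, 0.70f, 0.15f, 0.05f};
    // Repeated calls reuse the same storage, so no per-call allocation occurs.
    for (int iter = 0; iter < 3; iter++) {
        printf("best token: %d\n", pick_best(scratch, probs, n_vocab));
    }
    return 0;
}

Because clear() keeps the capacity, the hot loop performs no further allocations after the one-time reserve; that is the same property the diff relies on for vocab.probs_id, wctx.logits, and wctx.probs.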