aluminumbox committed on
Commit eb07486 · verified · 1 parent: fc258c9

Upload 194 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +249 -14
  3. app.py +200 -0
  4. cosyvoice/__init__.py +0 -0
  5. cosyvoice/bin/average_model.py +93 -0
  6. cosyvoice/bin/convert.py +223 -0
  7. cosyvoice/bin/export_jit.py +101 -0
  8. cosyvoice/bin/export_onnx.py +114 -0
  9. cosyvoice/bin/inference_deprecated.py +126 -0
  10. cosyvoice/bin/train.py +195 -0
  11. cosyvoice/cli/__init__.py +0 -0
  12. cosyvoice/cli/cosyvoice.py +238 -0
  13. cosyvoice/cli/frontend.py +219 -0
  14. cosyvoice/cli/model.py +430 -0
  15. cosyvoice/dataset/__init__.py +0 -0
  16. cosyvoice/dataset/dataset.py +151 -0
  17. cosyvoice/dataset/processor.py +443 -0
  18. cosyvoice/flow/DiT/dit.py +176 -0
  19. cosyvoice/flow/DiT/modules.py +616 -0
  20. cosyvoice/flow/decoder.py +494 -0
  21. cosyvoice/flow/flow.py +432 -0
  22. cosyvoice/flow/flow_matching.py +228 -0
  23. cosyvoice/flow/length_regulator.py +70 -0
  24. cosyvoice/hifigan/discriminator.py +230 -0
  25. cosyvoice/hifigan/f0_predictor.py +103 -0
  26. cosyvoice/hifigan/generator.py +746 -0
  27. cosyvoice/hifigan/hifigan.py +67 -0
  28. cosyvoice/llm/llm.py +739 -0
  29. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
  30. cosyvoice/tokenizer/tokenizer.py +327 -0
  31. cosyvoice/transformer/__init__.py +0 -0
  32. cosyvoice/transformer/activation.py +84 -0
  33. cosyvoice/transformer/attention.py +330 -0
  34. cosyvoice/transformer/convolution.py +258 -0
  35. cosyvoice/transformer/decoder.py +396 -0
  36. cosyvoice/transformer/decoder_layer.py +132 -0
  37. cosyvoice/transformer/embedding.py +302 -0
  38. cosyvoice/transformer/encoder.py +474 -0
  39. cosyvoice/transformer/encoder_layer.py +236 -0
  40. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  41. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  42. cosyvoice/transformer/subsampling.py +383 -0
  43. cosyvoice/transformer/upsample_encoder.py +321 -0
  44. cosyvoice/utils/__init__.py +0 -0
  45. cosyvoice/utils/class_utils.py +85 -0
  46. cosyvoice/utils/common.py +213 -0
  47. cosyvoice/utils/executor.py +176 -0
  48. cosyvoice/utils/file_utils.py +118 -0
  49. cosyvoice/utils/frontend_utils.py +136 -0
  50. cosyvoice/utils/losses.py +57 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ zero_shot_prompt.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,249 @@
1
- ---
2
- title: Fun CosyVoice3 0.5B
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 6.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Fun-CosyVoice3-0.5B
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ [![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
2
+
3
+ ## 👉🏻 CosyVoice 👈🏻
4
+
5
+ **Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [Modelscope](https://www.modelscope.cn/studios/FunAudioLLM/Fun-CosyVoice3-0.5B); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
6
+
7
+ **CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
8
+
9
+ **CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
10
+
11
+ ## Highlight🔥
12
+
13
+ **Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLMs). It surpasses its predecessor, CosyVoice 2.0, in content consistency, speaker similarity, and prosody naturalness, and is designed for zero-shot multilingual speech synthesis in the wild.
14
+ ### Key Features
15
+ - **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian) and 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shaanxi, Shanxi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.), while also supporting multilingual and cross-lingual zero-shot voice cloning.
16
+ - **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
17
+ - **Pronunciation Inpainting**: Supports pronunciation inpainting with Chinese pinyin and English CMU phonemes, providing finer control and making it suitable for production use (see the example after this list).
18
+ - **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
19
+ - **Bi-Streaming**: Supports both streaming text input and streaming audio output, achieving latency as low as 150 ms while maintaining high-quality audio.
20
+ - **Instruct Support**: Supports instructions controlling language, dialect, emotion, speaking speed, volume, and more.
21
+
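+ A concrete illustration of the bracket notation: the default input text of the demo app (`app.py`) in this commit mixes CMU phonemes for an English word and a pinyin syllable for a polyphonic Chinese character directly inside the text to synthesize. The sketch below simply reuses that string; the surrounding sentence is sample data, not a required format.
+
+ ``` python
+ # Square brackets carry phoneme-level pronunciation hints for pronunciation inpainting:
+ # [M][AY0][N][UW1][T] spells an English word with CMU phonemes, and
+ # [h][ào] fixes the reading of the polyphonic character 好 as "hào" via pinyin.
+ tts_text = "Her handwriting is [M][AY0][N][UW1][T]并且很整洁,说明她[h][ào]干净。"
+ ```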
22
+
23
+ ## Roadmap
24
+
25
+ - [x] 2025/12
26
+
27
+ - [x] release the Fun-CosyVoice3-0.5B-2512 base model, the RL model, and their training/inference scripts
28
+ - [x] release the Fun-CosyVoice3-0.5B ModelScope Gradio space
29
+
30
+ - [x] 2025/08
31
+
32
+ - [x] Thanks to a contribution from Yuekai Zhang (NVIDIA), add Triton TensorRT-LLM runtime support and CosyVoice2 GRPO training support
33
+
34
+ - [x] 2025/07
35
+
36
+ - [x] release Fun-CosyVoice 3.0 eval set
37
+
38
+ - [x] 2025/05
39
+
40
+ - [x] add CosyVoice2-0.5B vllm support
41
+
42
+ - [x] 2024/12
43
+
44
+ - [x] 25hz CosyVoice2-0.5B released
45
+
46
+ - [x] 2024/09
47
+
48
+ - [x] 25hz CosyVoice-300M base model
49
+ - [x] 25hz CosyVoice-300M voice conversion function
50
+
51
+ - [x] 2024/08
52
+
53
+ - [x] Repetition Aware Sampling (RAS) inference for LLM stability
54
+ - [x] Streaming inference mode support, including KV cache and SDPA for RTF optimization
55
+
56
+ - [x] 2024/07
57
+
58
+ - [x] Flow matching training support
59
+ - [x] WeTextProcessing support when ttsfrd is not available
60
+ - [x] FastAPI server and client
61
+
62
+ ## Evaluation
63
+ | Model | CER (%) ↓ (test-zh) | WER (%) ↓ (test-en) | CER (%) ↓ (test-hard) |
64
+ |-----|------------------|------------------|------------------|
65
+ | Human | 1.26 | 2.14 | - |
66
+ | F5-TTS | 1.53 | 2.00 | 8.67 |
67
+ | SparkTTS | 1.20 | 1.98 | - |
68
+ | Seed-TTS | 1.12 | 2.25 | 7.59 |
69
+ | CosyVoice2 | 1.45 | 2.57 | 6.83 |
70
+ | FireRedTTS-2 | 1.14 | 1.95 | - |
71
+ | IndexTTS2 | 1.01 | 1.52 | 7.12 |
72
+ | VibeVoice | 1.16 | 3.04 | - |
73
+ | HiggsAudio | 1.79 | 2.44 | - |
74
+ | MiniMax-Speech | 0.83 | 1.65 | - |
75
+ | VoxPCM | 0.93 | 1.85 | 8.87 |
76
+ | GLM-TTS | 1.03 | - | - |
77
+ | GLM-TTS_RL | 0.89 | - | - |
78
+ | Fun-CosyVoice3-0.5B-2512 | 1.21 | 2.24 | 6.71 |
79
+ | Fun-CosyVoice3-0.5B-2512_RL | 0.81 | 1.68 | 5.44 |
80
+
81
+
82
+ ## Install
83
+
84
+ ### Clone and install
85
+
86
+ - Clone the repo
87
+ ``` sh
88
+ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
89
+ # If cloning the submodule fails due to network issues, run the following command until it succeeds
90
+ cd CosyVoice
91
+ git submodule update --init --recursive
92
+ ```
93
+
94
+ - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
95
+ - Create Conda env:
96
+
97
+ ``` sh
98
+ conda create -n cosyvoice -y python=3.10
99
+ conda activate cosyvoice
100
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
101
+
102
+ # If you encounter sox compatibility issues
103
+ # ubuntu
104
+ sudo apt-get install sox libsox-dev
105
+ # centos
106
+ sudo yum install sox sox-devel
107
+ ```
108
+
109
+ ### Model download
110
+
111
+ We strongly recommend downloading our pretrained `Fun-CosyVoice3-0.5B`, `CosyVoice2-0.5B`, `CosyVoice-300M`, `CosyVoice-300M-SFT`, and `CosyVoice-300M-Instruct` models, as well as the `CosyVoice-ttsfrd` resource.
112
+
113
+ ``` python
114
+ # Download models via the ModelScope SDK
115
+ from modelscope import snapshot_download
116
+ snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
117
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
118
+ snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
119
+ snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
120
+ snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
121
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
122
+ ```
123
+
124
+ Optionally, you can unzip the `ttsfrd` resource and install the `ttsfrd` package for better text normalization performance.
125
+
126
+ Note that this step is not required; if you do not install the `ttsfrd` package, wetext is used by default.
127
+
128
+ ``` sh
129
+ cd pretrained_models/CosyVoice-ttsfrd/
130
+ unzip resource.zip -d .
131
+ pip install ttsfrd_dependency-0.1-py3-none-any.whl
132
+ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
133
+ ```
134
+
135
+ ### Basic Usage
136
+
137
+ We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
138
+ Follow the code in `example.py` for detailed usage of each model.
139
+ ```sh
140
+ python example.py
141
+ ```
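+ For a quick feel of the Python API, the sketch below mirrors the calls made by `app.py` in this commit (the `AutoModel` wrapper, the `<|endofprompt|>` prefix on the prompt transcript, the bundled `zero_shot_prompt.wav`, and the 24 kHz output rate). Treat it as an illustrative sketch rather than the authoritative reference; `example.py` remains the source of truth, and the instruct-mode texts below are only examples (the demo app picks its instructions from `cosyvoice.utils.common.instruct_list`).
+
+ ``` python
+ import sys
+ sys.path.append('third_party/Matcha-TTS')  # Matcha-TTS submodule, as done in app.py
+ import torch
+ import torchaudio
+ from cosyvoice.cli.cosyvoice import AutoModel
+
+ cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
+
+ # Zero-shot cloning: target text, prompt transcript (with system prefix), prompt audio path.
+ chunks = [out['tts_speech'] for out in cosyvoice.inference_zero_shot(
+     '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐。',
+     'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
+     'zero_shot_prompt.wav', stream=False)]
+ torchaudio.save('zero_shot_out.wav', torch.concat(chunks, dim=1), 24000)
+
+ # Instruct mode: the same prompt audio plus a natural-language instruction (illustrative text).
+ chunks = [out['tts_speech'] for out in cosyvoice.inference_instruct2(
+     '今天天气真不错,我们去公园散步吧。', '请用四川话说这句话。', 'zero_shot_prompt.wav', stream=False)]
+ torchaudio.save('instruct_out.wav', torch.concat(chunks, dim=1), 24000)
+ ```
+
+ Setting `stream=True` turns the generators into chunk-by-chunk streams, which is how the bi-streaming latency figure above is achieved.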
142
+
143
+ #### CosyVoice2 vllm Usage
144
+ If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vllm versions do not support CosyVoice2 inference.
145
+
146
+ Note that `vllm==v0.9.0` has many specific requirements, for example `torch==2.7.0`. You can create a new environment in case your hardware does not support vllm, so that your existing environment is not corrupted.
147
+
148
+ ``` sh
149
+ conda create -n cosyvoice_vllm --clone cosyvoice
150
+ conda activate cosyvoice_vllm
151
+ pip install vllm==v0.9.0 transformers==4.51.3 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
152
+ python vllm_example.py
153
+ ```
154
+
155
+ #### Start web demo
156
+
157
+ You can use our web demo page to get familiar with CosyVoice quickly.
158
+
159
+ Please see the demo website for details.
160
+
161
+ ``` sh
162
+ # change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
163
+ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
164
+ ```
165
+
166
+ #### Advanced Usage
167
+
168
+ For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
169
+
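+ A hedged sketch of how such a recipe is typically driven; the stage layout is an assumption based on the WeNet-style convention these recipes follow, so check the script itself for the actual stages and options:
+
+ ``` sh
+ cd examples/libritts/cosyvoice
+ # Stages cover data preparation, feature/token extraction, training and export.
+ # Select which stages to run by editing the stage/stop_stage variables at the
+ # top of run.sh, then launch the recipe:
+ bash run.sh
+ ```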
170
+ #### Build for deployment
171
+
172
+ Optionally, if you want to deploy CosyVoice as a service,
173
+ you can follow the steps below.
174
+
175
+ ``` sh
176
+ cd runtime/python
177
+ docker build -t cosyvoice:v1.0 .
178
+ # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
179
+ # for grpc usage
180
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
181
+ cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
182
+ # for fastapi usage
183
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
184
+ cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
185
+ ```
186
+
187
+ #### Using Nvidia TensorRT-LLM for deployment
188
+
189
+ Using TensorRT-LLM to accelerate the CosyVoice2 LLM can give about a 4x speedup compared with the Hugging Face Transformers implementation.
190
+ To get started quickly:
191
+
192
+ ``` sh
193
+ cd runtime/triton_trtllm
194
+ docker compose up -d
195
+ ```
196
+ For more details, see [runtime/triton_trtllm](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm).
197
+
198
+ ## Discussion & Communication
199
+
200
+ You can discuss directly on [GitHub Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
201
+
202
+ You can also scan the QR code to join our official Dingding chat group.
203
+
204
+ <img src="./asset/dingding.png" width="250px">
205
+
206
+ ## Acknowledgements
207
+
208
+ 1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
209
+ 2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
210
+ 3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
211
+ 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
212
+ 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
213
+
214
+ ## Citations
215
+
216
+ ``` bibtex
217
+ @article{du2024cosyvoice,
218
+ title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
219
+ author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
220
+ journal={arXiv preprint arXiv:2407.05407},
221
+ year={2024}
222
+ }
223
+
224
+ @article{du2024cosyvoice,
225
+ title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
226
+ author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
227
+ journal={arXiv preprint arXiv:2412.10117},
228
+ year={2024}
229
+ }
230
+
231
+ @article{du2025cosyvoice,
232
+ title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
233
+ author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
234
+ journal={arXiv preprint arXiv:2505.17589},
235
+ year={2025}
236
+ }
237
+
238
+ @inproceedings{lyu2025build,
239
+ title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
240
+ author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
241
+ booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
242
+ pages={1--2},
243
+ year={2025},
244
+ organization={IEEE}
245
+ }
246
+ ```
247
+
248
+ ## Disclaimer
249
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
app.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+ import gradio as gr
18
+ import numpy as np
19
+ import torch
20
+ import torchaudio
21
+ import random
22
+ import librosa
23
+ from funasr import AutoModel
24
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
25
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
27
+
28
+ from modelscope import snapshot_download, HubApi
29
+
30
+ api = HubApi()
31
+ _, cookies = api.login(access_token=os.environ['token'])
32
+ snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B', cookies=cookies)
33
+ snapshot_download('iic/SenseVoiceSmall', local_dir='pretrained_models/SenseVoiceSmall', cookies=cookies)
34
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd', cookies=cookies)
35
+ os.system('cd pretrained_models/CosyVoice-ttsfrd/ && pip install ttsfrd_dependency-0.1-py3-none-any.whl && pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl && apt install -y unzip && rm -rf resource && unzip resource.zip -d .')
36
+
37
+ from cosyvoice.cli.cosyvoice import AutoModel as CosyVoiceAutoModel
38
+ from cosyvoice.utils.file_utils import logging, load_wav
39
+ from cosyvoice.utils.common import set_all_random_seed, instruct_list
40
+
41
+ inference_mode_list = ['3s极速复刻', '自然语言控制']
42
+ instruct_dict = {'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
43
+ '自然语言控制': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入instruct文本\n3. 点击生成音频按钮'}
44
+ stream_mode_list = [('否', False)]
45
+ max_val = 0.8
46
+
47
+
48
+ def generate_seed():
49
+ seed = random.randint(1, 100000000)
50
+ return {
51
+ "__type__": "update",
52
+ "value": seed
53
+ }
54
+
55
+ top_db = 60
56
+ hop_length = 220
57
+ win_length = 440
58
+ def postprocess(wav):
59
+ speech = load_wav(wav, target_sr=target_sr, min_sr=16000)
60
+ speech, _ = librosa.effects.trim(
61
+ speech, top_db=top_db,
62
+ frame_length=win_length,
63
+ hop_length=hop_length
64
+ )
65
+ if speech.abs().max() > max_val:
66
+ speech = speech / speech.abs().max() * max_val
67
+ speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
68
+ torchaudio.save(wav, speech, target_sr)
69
+ return wav
70
+
71
+
72
+ def change_instruction(mode_checkbox_group):
73
+ return instruct_dict[mode_checkbox_group]
74
+
75
+ def prompt_wav_recognition(prompt_wav):
76
+ res = asr_model.generate(input=prompt_wav,
77
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
78
+ use_itn=True,
79
+ )
80
+ text = res[0]["text"].split('|>')[-1]
81
+ return text
82
+
83
+ def generate_audio(tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
84
+ seed, stream):
85
+ stream = False
86
+ if len(tts_text) > 200:
87
+ gr.Warning('您输入的文字过长,请限制在200字以内')
88
+ return (target_sr, default_data)
89
+ sft_dropdown, speed = '', 1.0
90
+ if prompt_wav_upload is not None:
91
+ prompt_wav = prompt_wav_upload
92
+ elif prompt_wav_record is not None:
93
+ prompt_wav = prompt_wav_record
94
+ else:
95
+ prompt_wav = None
96
+ # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
97
+ if mode_checkbox_group in ['自然语言控制']:
98
+ if instruct_text == '':
99
+ gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
100
+ return (target_sr, default_data)
101
+ if prompt_wav is None:
102
+ gr.Info('您正在使用自然语言控制模式, 请输入prompt音频')
103
+ return (target_sr, default_data)
104
+ # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
105
+ if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
106
+ if prompt_wav is None:
107
+ gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
108
+ return (target_sr, default_data)
109
+ info = torchaudio.info(prompt_wav)
110
+ if info.sample_rate < prompt_sr:
111
+ gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
112
+ return (target_sr, default_data)
113
+ if info.num_frames / info.sample_rate > 10:
114
+ gr.Warning('请限制输入音频在10s内,避免推理效果过低')
115
+ return (target_sr, default_data)
116
+ # zero_shot mode only uses prompt_wav and prompt_text
117
+ if mode_checkbox_group in ['3s极速复刻']:
118
+ if prompt_text == '':
119
+ gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
120
+ return (target_sr, default_data)
121
+ if instruct_text != '':
122
+ gr.Info('您正在使用3s极速复刻模式,instruct文本会被忽略!')
123
+ info = torchaudio.info(prompt_wav)
124
+ if info.num_frames / info.sample_rate > 10:
125
+ gr.Warning('请限制输入音频在10s内,避免推理效果过低')
126
+ return (target_sr, default_data)
127
+ if mode_checkbox_group == '3s极速复刻':
128
+ logging.info('get zero_shot inference request')
129
+ set_all_random_seed(seed)
130
+ speech_list = []
131
+ for i in cosyvoice.inference_zero_shot(tts_text, 'You are a helpful assistant.<|endofprompt|>' + prompt_text, postprocess(prompt_wav), stream=stream, speed=speed):
132
+ speech_list.append(i['tts_speech'])
133
+ return (target_sr, torch.concat(speech_list, dim=1).numpy().flatten())
134
+ elif mode_checkbox_group == '自然语言控制':
135
+ logging.info('get instruct inference request')
136
+ set_all_random_seed(seed)
137
+ speech_list = []
138
+ for i in cosyvoice.inference_instruct2(tts_text, instruct_text, postprocess(prompt_wav), stream=stream, speed=speed):
139
+ speech_list.append(i['tts_speech'])
140
+ return (target_sr, torch.concat(speech_list, dim=1).numpy().flatten())
141
+ else:
142
+ gr.Warning('无效的模式选择')
143
+
144
+
145
+ def main():
146
+ with gr.Blocks() as demo:
147
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
148
+ 预训练模型 [Fun-CosyVoice3-0.5B](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B) \
149
+ [CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \
150
+ [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
151
+ [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
152
+ [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
153
+ gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
154
+
155
+ tts_text = gr.Textbox(label="输入合成文本", lines=1, value="Her handwriting is [M][AY0][N][UW1][T]并且很整洁,说明她[h][ào]干净。")
156
+ with gr.Row():
157
+ mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
158
+ instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
159
+ stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
160
+ with gr.Column(scale=0.25):
161
+ seed_button = gr.Button(value="\U0001F3B2")
162
+ seed = gr.Number(value=0, label="随机推理种子")
163
+
164
+ with gr.Row():
165
+ prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
166
+ prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
167
+ prompt_text = gr.Textbox(label="prompt文本", lines=1, placeholder="请输入prompt文本,支持自动识别,您可以自行修正识别结果...", value='')
168
+ instruct_text = gr.Dropdown(choices=instruct_list, label='选择instruct文本', value=instruct_list[0])
169
+
170
+ generate_button = gr.Button("生成音频")
171
+
172
+ audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=False)
173
+
174
+ seed_button.click(generate_seed, inputs=[], outputs=seed)
175
+ generate_button.click(generate_audio,
176
+ inputs=[tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
177
+ seed, stream],
178
+ outputs=[audio_output])
179
+ mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
180
+ prompt_wav_upload.change(fn=prompt_wav_recognition, inputs=[prompt_wav_upload], outputs=[prompt_text])
181
+ prompt_wav_record.change(fn=prompt_wav_recognition, inputs=[prompt_wav_record], outputs=[prompt_text])
182
+ demo.queue(default_concurrency_limit=4).launch(server_port=50000, server_name='0.0.0.0')
183
+
184
+
185
+ if __name__ == '__main__':
186
+ cosyvoice = CosyVoiceAutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B', load_trt=True, fp16=False)
187
+ sft_spk = cosyvoice.list_available_spks()
188
+ for stream in [False]:
189
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。', 'zero_shot_prompt.wav', stream=stream)):
190
+ continue
191
+ prompt_sr, target_sr = 16000, 24000
192
+ default_data = np.zeros(target_sr)
193
+
194
+ model_dir = "pretrained_models/SenseVoiceSmall"
195
+ asr_model = AutoModel(
196
+ model=model_dir,
197
+ disable_update=True,
198
+ log_level='DEBUG',
199
+ device="cuda:0")
200
+ main()
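A note on running this Space script locally: `app.py` downloads its models from ModelScope using an access token read from the `token` environment variable (see the `HubApi().login` call above) and serves the Gradio UI on port 50000. A minimal launch sketch with a placeholder token:

``` sh
export token=<your_modelscope_access_token>  # consumed via os.environ['token'] in app.py
python app.py                                # Gradio UI listens on 0.0.0.0:50000
```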
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/bin/average_model.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import argparse
18
+ import glob
19
+
20
+ import yaml
21
+ import torch
22
+
23
+
24
+ def get_args():
25
+ parser = argparse.ArgumentParser(description='average model')
26
+ parser.add_argument('--dst_model', required=True, help='averaged model')
27
+ parser.add_argument('--src_path',
28
+ required=True,
29
+ help='src model path for average')
30
+ parser.add_argument('--val_best',
31
+ action="store_true",
32
+ help='average the checkpoints with the best (lowest) validation loss')
33
+ parser.add_argument('--num',
34
+ default=5,
35
+ type=int,
36
+ help='number of checkpoints to average')
37
+
38
+ args = parser.parse_args()
39
+ print(args)
40
+ return args
41
+
42
+
43
+ def main():
44
+ args = get_args()
45
+ val_scores = []
46
+ if args.val_best:
47
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
48
+ yamls = [
49
+ f for f in yamls
50
+ if not (os.path.basename(f).startswith('train')
51
+ or os.path.basename(f).startswith('init'))
52
+ ]
53
+ for y in yamls:
54
+ with open(y, 'r') as f:
55
+ dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
56
+ loss = float(dic_yaml['loss_dict']['loss'])
57
+ epoch = int(dic_yaml['epoch'])
58
+ step = int(dic_yaml['step'])
59
+ tag = dic_yaml['tag']
60
+ val_scores += [[epoch, step, loss, tag]]
61
+ sorted_val_scores = sorted(val_scores,
62
+ key=lambda x: x[2],
63
+ reverse=False)
64
+ print("best val (epoch, step, loss, tag) = " +
65
+ str(sorted_val_scores[:args.num]))
66
+ path_list = [
67
+ args.src_path + '/epoch_{}_whole.pt'.format(score[0])
68
+ for score in sorted_val_scores[:args.num]
69
+ ]
70
+ print(path_list)
71
+ avg = {}
72
+ num = args.num
73
+ assert num == len(path_list)
74
+ for path in path_list:
75
+ print('Processing {}'.format(path))
76
+ states = torch.load(path, map_location=torch.device('cpu'))
77
+ for k in states.keys():
78
+ if k not in ['step', 'epoch']:
79
+ if k not in avg.keys():
80
+ avg[k] = states[k].clone()
81
+ else:
82
+ avg[k] += states[k]
83
+ # average
84
+ for k in avg.keys():
85
+ if avg[k] is not None:
86
+ # pytorch 1.6 use true_divide instead of /=
87
+ avg[k] = torch.true_divide(avg[k], num)
88
+ print('Saving to {}'.format(args.dst_model))
89
+ torch.save(avg, args.dst_model)
90
+
91
+
92
+ if __name__ == '__main__':
93
+ main()
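`average_model.py` ranks checkpoints by the validation loss recorded in their per-checkpoint YAML files and averages the best N sets of weights. A hedged invocation sketch; the `exp/llm` paths are placeholders:

``` sh
# --src_path must contain epoch_*_whole.pt checkpoints plus their *.yaml metadata
# (loss/epoch/step/tag); --val_best picks the 5 checkpoints with the lowest loss.
python cosyvoice/bin/average_model.py \
    --src_path exp/llm \
    --dst_model exp/llm/llm.average5.pt \
    --val_best --num 5
```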
cosyvoice/bin/convert.py ADDED
@@ -0,0 +1,223 @@
1
+ import sys
2
+ import torch
3
+
4
+ def convert_llm(state_dict):
5
+ # The LM structure was reorganized: codec_lm.encoder becomes llm, codec_lm.decoder becomes llm_decoder
6
+ keys = list(state_dict.keys())
7
+ for k in keys:
8
+ if k.startswith('codec_lm.encoder.'):
9
+ v = state_dict.pop(k)
10
+ k = k.replace('codec_lm.encoder.', 'llm.')
11
+ state_dict[k] = v
12
+ if k.startswith('codec_lm.decoder.'):
13
+ v = state_dict.pop(k)
14
+ k = k.replace('codec_lm.decoder.', 'llm_decoder.')
15
+ state_dict[k] = v
16
+ # Naming differences between the ESPnet and WeNet implementations
17
+ keys = list(state_dict.keys())
18
+ for k in keys:
19
+ if k.startswith('text_encoder.embed.'):
20
+ v = state_dict.pop(k)
21
+ k = k.replace('text_encoder.embed.', 'text_encoder.embed.out.')
22
+ state_dict[k] = v
23
+ if k.startswith('llm.embed.'):
24
+ v = state_dict.pop(k)
25
+ k = k.replace('llm.embed.', 'llm.embed.out.')
26
+ state_dict[k] = v
27
+ keys = list(state_dict.keys())
28
+ for k in keys:
29
+ if k.startswith('text_enc_out_layer.'):
30
+ v = state_dict.pop(k)
31
+ k = k.replace('text_enc_out_layer.', 'text_encoder_affine_layer.')
32
+ state_dict[k] = v
33
+ if k.startswith('token_embedding.'):
34
+ v = state_dict.pop(k)
35
+ k = k.replace('token_embedding.', 'text_embedding.')
36
+ state_dict[k] = v
37
+ if k.startswith('xvec_proj.'):
38
+ v = state_dict.pop(k)
39
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
40
+ state_dict[k] = v
41
+ if k.startswith('lm_embedding.'):
42
+ v = state_dict.pop(k)
43
+ k = k.replace('lm_embedding.', 'llm_embedding.')
44
+ state_dict[k] = v
45
+ if k.startswith('codec_embedder.'):
46
+ v = state_dict.pop(k)
47
+ k = k.replace('codec_embedder.', 'speech_embedding.')
48
+ state_dict[k] = v
49
+ # The instruct model lacks the spk embedding parameters; add all-zero ones
50
+ keys = list(state_dict.keys())
51
+ if 'spk_embed_affine_layer.weight' not in keys:
52
+ print('no spk_embed_affine_layer.weight, should be instruct model')
53
+ state_dict['spk_embed_affine_layer.weight'] = torch.zeros(1024, 192)
54
+ if 'spk_embed_affine_layer.bias' not in keys:
55
+ print('no spk_embed_affine_layer.bias, should be instruct model')
56
+ state_dict['spk_embed_affine_layer.bias'] = torch.zeros(1024)
57
+ return state_dict
58
+
59
+ def convert_hift(state_dict):
60
+ # The hifigan structure in cosyvoice was reorganized: f0_predictor is moved into the generator
61
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith('discriminator.')}
62
+ keys = list(state_dict.keys())
63
+ for k in keys:
64
+ if k in ['step', 'epoch']:
65
+ del state_dict[k]
66
+ if k.startswith('decoder.'):
67
+ v = state_dict.pop(k)
68
+ k = k.replace('decoder.', '')
69
+ state_dict[k] = v
70
+ if k.startswith('generator.'):
71
+ v = state_dict.pop(k)
72
+ k = k.replace('generator.', '')
73
+ state_dict[k] = v
74
+ return state_dict
75
+
76
+ def convert_flow(state_dict):
77
+ keys = list(state_dict.keys())
78
+ for k in keys:
79
+ if k.startswith('encoder.embed.'):
80
+ v = state_dict.pop(k)
81
+ k = k.replace('encoder.embed.', 'encoder.embed.out.')
82
+ state_dict[k] = v
83
+ for k in keys:
84
+ if k.startswith('xvec_proj.'):
85
+ v = state_dict.pop(k)
86
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
87
+ state_dict[k] = v
88
+ return state_dict
89
+
90
+ def convert_llm2(state_dict):
91
+ # The LM structure was reorganized: codec_lm.encoder becomes llm, codec_lm.decoder becomes llm_decoder
92
+ keys = list(state_dict.keys())
93
+ for k in keys:
94
+ if k.startswith('codec_lm.encoder.'):
95
+ v = state_dict.pop(k)
96
+ k = k.replace('codec_lm.encoder.', 'llm.')
97
+ state_dict[k] = v
98
+ if k.startswith('codec_lm.decoder.'):
99
+ v = state_dict.pop(k)
100
+ k = k.replace('codec_lm.decoder.', 'llm_decoder.')
101
+ state_dict[k] = v
102
+ if k.startswith('lm_embedding.'):
103
+ v = state_dict.pop(k)
104
+ k = k.replace('lm_embedding.', 'llm_embedding.')
105
+ state_dict[k] = v
106
+ if k.startswith('codec_embedder.'):
107
+ v = state_dict.pop(k)
108
+ k = k.replace('codec_embedder.', 'speech_embedding.')
109
+ state_dict[k] = v
110
+ if k.startswith('text_enc_out_layer.'):
111
+ state_dict.pop(k)
112
+ if k.startswith('token_embedding.weight'):
113
+ state_dict.pop(k)
114
+ return state_dict
115
+
116
+ def convert_llm3(state_dict):
117
+ # The LM structure was reorganized: codec_lm.encoder becomes llm, codec_lm.decoder becomes llm_decoder
118
+ keys = list(state_dict.keys())
119
+ state_dict = {k: v for k, v in state_dict.items() if (not k.startswith('reward') and not k.startswith('ref'))}
120
+ for k in keys:
121
+ if k.startswith('llm.model.'):
122
+ v = state_dict.pop(k)
123
+ k = k.replace('llm.model.', 'llm.model.model.')
124
+ state_dict[k] = v
125
+ if k.startswith('codec_head.'):
126
+ v = state_dict.pop(k)
127
+ state_dict[k.replace('codec_head.', 'llm_decoder.')] = v
128
+ if k.startswith('codec_embed.'):
129
+ v = state_dict.pop(k)
130
+ k = k.replace('codec_embed.', 'speech_embedding.')
131
+ state_dict[k] = v
132
+ state_dict['llm.model.lm_head.weight'] = state_dict['llm.model.model.embed_tokens.weight']
133
+ return state_dict
134
+
135
+ def convert_flow2(state_dict):
136
+ keys = list(state_dict.keys())
137
+ for k in keys:
138
+ if k.startswith('encoder.embed.'):
139
+ v = state_dict.pop(k)
140
+ k = k.replace('encoder.embed.', 'encoder.embed.out.')
141
+ state_dict[k] = v
142
+ for k in keys:
143
+ if k.startswith('xvec_proj.'):
144
+ v = state_dict.pop(k)
145
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
146
+ state_dict[k] = v
147
+ for k in keys:
148
+ if k.startswith('mel_extractor.'):
149
+ state_dict.pop(k)
150
+ for k in keys:
151
+ if k.startswith('encoder.upsample_blocks.0.0.'):
152
+ v = state_dict.pop(k)
153
+ k = k.replace('encoder.upsample_blocks.0.0.', 'encoder.up_layer.')
154
+ state_dict[k] = v
155
+ if k.startswith('encoder.upsample_blocks.0.1.'):
156
+ v = state_dict.pop(k)
157
+ k = k.replace('encoder.upsample_blocks.0.1.', 'encoder.up_embed.out.')
158
+ state_dict[k] = v
159
+ if k.startswith('encoder.upsample_blocks.0.2.'):
160
+ v = state_dict.pop(k)
161
+ k = k.replace('encoder.upsample_blocks.0.2.', 'encoder.up_encoders.')
162
+ state_dict[k] = v
163
+ # In CausalBlock1D the index inside the sequential block changes from 1 to 2
164
+ if k.startswith('decoder.estimator.') and k.endswith('block.1.weight'):
165
+ v = state_dict.pop(k)
166
+ k = k.replace('block.1.weight', 'block.2.weight')
167
+ state_dict[k] = v
168
+ if k.startswith('decoder.estimator.') and k.endswith('block.1.bias'):
169
+ v = state_dict.pop(k)
170
+ k = k.replace('block.1.bias', 'block.2.bias')
171
+ state_dict[k] = v
172
+ return state_dict
173
+
174
+ def convert_flow3(state_dict):
175
+ keys = list(state_dict.keys())
176
+ for k in keys:
177
+ if k.startswith('xvec_proj.'):
178
+ v = state_dict.pop(k)
179
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
180
+ state_dict[k] = v
181
+ if k.startswith('codec_embedder.'):
182
+ v = state_dict.pop(k)
183
+ k = k.replace('codec_embedder.', 'input_embedding.')
184
+ state_dict[k] = v
185
+ if k.startswith('lookahead_conv1d.'):
186
+ v = state_dict.pop(k)
187
+ k = k.replace('lookahead_conv1d.', 'pre_lookahead_layer.')
188
+ state_dict[k] = v
189
+ for k in keys:
190
+ if k.startswith('mel_extractor.'):
191
+ state_dict.pop(k)
192
+ for k in keys:
193
+ # In CausalBlock1D the index inside the sequential block changes from 1 to 2
194
+ if k.startswith('dit_model.'):
195
+ v = state_dict.pop(k)
196
+ k = k.replace('dit_model.', 'decoder.estimator.')
197
+ state_dict[k] = v
198
+ if k in ['epoch', 'step']:
199
+ state_dict.pop(k)
200
+ return state_dict
201
+
202
+ if __name__ == '__main__':
203
+ # Usage: python3 convert.py <original-format llm.pt> llm <new-format llm.pt>
204
+ state_dict = torch.load(sys.argv[1], map_location='cpu')
205
+ if 'state_dict' in state_dict:
206
+ state_dict = state_dict['state_dict']
207
+ if sys.argv[2] == 'llm':
208
+ state_dict = convert_llm(state_dict)
209
+ elif sys.argv[2] == 'flow':
210
+ state_dict = convert_flow(state_dict)
211
+ elif sys.argv[2] == 'hift':
212
+ state_dict = convert_hift(state_dict)
213
+ elif sys.argv[2] == 'llm2':
214
+ state_dict = convert_llm2(state_dict)
215
+ elif sys.argv[2] == 'llm3':
216
+ state_dict = convert_llm3(state_dict)
217
+ elif sys.argv[2] == 'flow2':
218
+ state_dict = convert_flow2(state_dict)
219
+ elif sys.argv[2] == 'flow3':
220
+ state_dict = convert_flow3(state_dict)
221
+ else:
222
+ raise ValueError
223
+ torch.save(state_dict, sys.argv[3])
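`convert.py` takes three positional arguments: the original checkpoint, a mode selecting the key-renaming scheme (`llm`, `flow`, `hift`, `llm2`, `llm3`, `flow2`, or `flow3`), and the output path. A usage sketch with placeholder filenames:

``` sh
# Convert an original-format LLM checkpoint into the layout expected by cosyvoice.cli
python3 cosyvoice/bin/convert.py original_llm.pt llm llm.pt
```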
cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import AutoModel
27
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
28
+ from cosyvoice.utils.file_utils import logging
29
+ from cosyvoice.utils.class_utils import get_model_type
30
+
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser(description='export your model for deployment')
34
+ parser.add_argument('--model_dir',
35
+ type=str,
36
+ default='pretrained_models/CosyVoice-300M',
37
+ help='local path')
38
+ args = parser.parse_args()
39
+ print(args)
40
+ return args
41
+
42
+
43
+ def get_optimized_script(model, preserved_attrs=[]):
44
+ script = torch.jit.script(model)
45
+ if preserved_attrs != []:
46
+ script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
47
+ else:
48
+ script = torch.jit.freeze(script)
49
+ script = torch.jit.optimize_for_inference(script)
50
+ return script
51
+
52
+
53
+ def main():
54
+ args = get_args()
55
+ logging.basicConfig(level=logging.DEBUG,
56
+ format='%(asctime)s %(levelname)s %(message)s')
57
+
58
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
59
+ torch._C._jit_set_profiling_mode(False)
60
+ torch._C._jit_set_profiling_executor(False)
61
+
62
+ model = AutoModel(model_dir=args.model_dir)
63
+
64
+ if get_model_type(model.model) == CosyVoiceModel:
65
+ # 1. export flow encoder
66
+ flow_encoder = model.model.flow.encoder
67
+ script = get_optimized_script(flow_encoder)
68
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
69
+ script = get_optimized_script(flow_encoder.half())
70
+ script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
71
+ logging.info('successfully export flow_encoder')
72
+ elif get_model_type(model.model) == CosyVoice2Model:
73
+ # 1. export llm text_encoder
74
+ llm_text_encoder = model.model.llm.text_encoder
75
+ script = get_optimized_script(llm_text_encoder)
76
+ script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
77
+ script = get_optimized_script(llm_text_encoder.half())
78
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
79
+ logging.info('successfully export llm_text_encoder')
80
+
81
+ # 2. export llm llm
82
+ llm_llm = model.model.llm.llm
83
+ script = get_optimized_script(llm_llm, ['forward_chunk'])
84
+ script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
85
+ script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
86
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
87
+ logging.info('successfully export llm_llm')
88
+
89
+ # 3. export flow encoder
90
+ flow_encoder = model.model.flow.encoder
91
+ script = get_optimized_script(flow_encoder)
92
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
93
+ script = get_optimized_script(flow_encoder.half())
94
+ script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
95
+ logging.info('successfully export flow_encoder')
96
+ else:
97
+ raise ValueError('unsupported model type')
98
+
99
+
100
+ if __name__ == '__main__':
101
+ main()
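A usage sketch for the TorchScript export; the exported `*.zip` files are written back into the given model directory (the directory below is just one of the pretrained models mentioned in the README):

``` sh
python cosyvoice/bin/export_jit.py --model_dir pretrained_models/CosyVoice2-0.5B
```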
cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import AutoModel
31
+ from cosyvoice.utils.file_utils import logging
32
+
33
+
34
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
35
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
36
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
37
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
38
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
39
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
40
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
41
+ return x, mask, mu, t, spks, cond
42
+
43
+
44
+ def get_args():
45
+ parser = argparse.ArgumentParser(description='export your model for deployment')
46
+ parser.add_argument('--model_dir',
47
+ type=str,
48
+ default='pretrained_models/CosyVoice-300M',
49
+ help='local path')
50
+ args = parser.parse_args()
51
+ print(args)
52
+ return args
53
+
54
+
55
+ @torch.no_grad()
56
+ def main():
57
+ args = get_args()
58
+ logging.basicConfig(level=logging.DEBUG,
59
+ format='%(asctime)s %(levelname)s %(message)s')
60
+
61
+ model = AutoModel(model_dir=args.model_dir)
62
+
63
+ # 1. export flow decoder estimator
64
+ estimator = model.model.flow.decoder.estimator
65
+ estimator.eval()
66
+
67
+ device = model.model.device
68
+ batch_size, seq_len = 2, 256
69
+ out_channels = model.model.flow.decoder.estimator.out_channels
70
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
71
+ torch.onnx.export(
72
+ estimator,
73
+ (x, mask, mu, t, spks, cond),
74
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
75
+ export_params=True,
76
+ opset_version=18,
77
+ do_constant_folding=True,
78
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
79
+ output_names=['estimator_out'],
80
+ dynamic_axes={
81
+ 'x': {2: 'seq_len'},
82
+ 'mask': {2: 'seq_len'},
83
+ 'mu': {2: 'seq_len'},
84
+ 'cond': {2: 'seq_len'},
85
+ 'estimator_out': {2: 'seq_len'},
86
+ }
87
+ )
88
+
89
+ # 2. test computation consistency
90
+ option = onnxruntime.SessionOptions()
91
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
92
+ option.intra_op_num_threads = 1
93
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
94
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
95
+ sess_options=option, providers=providers)
96
+
97
+ for _ in tqdm(range(10)):
98
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
99
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
100
+ ort_inputs = {
101
+ 'x': x.cpu().numpy(),
102
+ 'mask': mask.cpu().numpy(),
103
+ 'mu': mu.cpu().numpy(),
104
+ 't': t.cpu().numpy(),
105
+ 'spks': spks.cpu().numpy(),
106
+ 'cond': cond.cpu().numpy()
107
+ }
108
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
109
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
110
+ logging.info('successfully export estimator')
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
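Likewise for the ONNX export of the flow decoder estimator, which also runs a quick PyTorch-vs-ONNX Runtime consistency check on random inputs:

``` sh
python cosyvoice/bin/export_onnx.py --model_dir pretrained_models/CosyVoice2-0.5B
```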
cosyvoice/bin/inference_deprecated.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ import torchaudio
24
+ from hyperpyyaml import load_hyperpyyaml
25
+ from tqdm import tqdm
26
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
27
+ from cosyvoice.dataset.dataset import Dataset
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser(description='inference with your model')
32
+ parser.add_argument('--config', required=True, help='config file')
33
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
34
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35
+ parser.add_argument('--tts_text', required=True, help='tts input file')
36
+ parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
37
+ parser.add_argument('--llm_model', required=True, help='llm model file')
38
+ parser.add_argument('--flow_model', required=True, help='flow model file')
39
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40
+ parser.add_argument('--gpu',
41
+ type=int,
42
+ default=-1,
43
+ help='gpu id for this rank, -1 for cpu')
44
+ parser.add_argument('--mode',
45
+ default='sft',
46
+ choices=['sft', 'zero_shot'],
47
+ help='inference mode')
48
+ parser.add_argument('--result_dir', required=True, help='asr result file')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
59
+
60
+ # Init cosyvoice models from configs
61
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
62
+ device = torch.device('cuda' if use_cuda else 'cpu')
63
+ try:
64
+ with open(args.config, 'r') as f:
65
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
66
+ model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
67
+ except Exception:
68
+ try:
69
+ with open(args.config, 'r') as f:
70
+ configs = load_hyperpyyaml(f)
71
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
72
+ except Exception:
73
+ raise TypeError('no valid model_type!')
74
+
75
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
76
+
77
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
78
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
79
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
80
+
81
+ sample_rate = configs['sample_rate']
82
+ del configs
83
+ os.makedirs(args.result_dir, exist_ok=True)
84
+ fn = os.path.join(args.result_dir, 'wav.scp')
85
+ f = open(fn, 'w')
86
+ with torch.no_grad():
87
+ for _, batch in tqdm(enumerate(test_data_loader)):
88
+ utts = batch["utts"]
89
+ assert len(utts) == 1, "inference mode only support batchsize 1"
90
+ text_token = batch["text_token"].to(device)
91
+ text_token_len = batch["text_token_len"].to(device)
92
+ tts_index = batch["tts_index"]
93
+ tts_text_token = batch["tts_text_token"].to(device)
94
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
95
+ speech_token = batch["speech_token"].to(device)
96
+ speech_token_len = batch["speech_token_len"].to(device)
97
+ speech_feat = batch["speech_feat"].to(device)
98
+ speech_feat_len = batch["speech_feat_len"].to(device)
99
+ utt_embedding = batch["utt_embedding"].to(device)
100
+ spk_embedding = batch["spk_embedding"].to(device)
101
+ if args.mode == 'sft':
102
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
103
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
104
+ else:
105
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
106
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
107
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
108
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
109
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
110
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
111
+ tts_speeches = []
112
+ for model_output in model.tts(**model_input):
113
+ tts_speeches.append(model_output['tts_speech'])
114
+ tts_speeches = torch.concat(tts_speeches, dim=1)
115
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
116
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
117
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
118
+ f.write('{} {}\n'.format(tts_key, tts_fn))
119
+ f.flush()
120
+ f.close()
121
+ logging.info('Result wav.scp saved in {}'.format(fn))
122
+
123
+
124
+ if __name__ == '__main__':
125
+ logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
126
+ main()
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.losses import DPOLoss
31
+ from cosyvoice.utils.executor import Executor
32
+ from cosyvoice.utils.train_utils import (
33
+ init_distributed,
34
+ init_dataset_and_dataloader,
35
+ init_optimizer_and_scheduler,
36
+ init_summarywriter, save_model,
37
+ wrap_cuda_model, check_modify_and_save_config)
38
+
39
+
40
+ def get_args():
41
+ parser = argparse.ArgumentParser(description='training your network')
42
+ parser.add_argument('--train_engine',
43
+ default='torch_ddp',
44
+ choices=['torch_ddp', 'deepspeed'],
45
+ help='Engine for paralleled training')
46
+ parser.add_argument('--model', required=True, help='model which will be trained')
47
+ parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
48
+ parser.add_argument('--config', required=True, help='config file')
49
+ parser.add_argument('--train_data', required=True, help='train data file')
50
+ parser.add_argument('--cv_data', required=True, help='cv data file')
51
+ parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
52
+ parser.add_argument('--checkpoint', help='checkpoint model')
53
+ parser.add_argument('--model_dir', required=True, help='save model dir')
54
+ parser.add_argument('--tensorboard_dir',
55
+ default='tensorboard',
56
+ help='tensorboard log dir')
57
+ parser.add_argument('--ddp.dist_backend',
58
+ dest='dist_backend',
59
+ default='nccl',
60
+ choices=['nccl', 'gloo'],
61
+ help='distributed backend')
62
+ parser.add_argument('--num_workers',
63
+ default=0,
64
+ type=int,
65
+ help='num of subprocess workers for reading')
66
+ parser.add_argument('--prefetch',
67
+ default=100,
68
+ type=int,
69
+ help='prefetch number')
70
+ parser.add_argument('--pin_memory',
71
+ action='store_true',
72
+ default=False,
73
+ help='Use pinned memory buffers used for reading')
74
+ parser.add_argument('--use_amp',
75
+ action='store_true',
76
+ default=False,
77
+ help='Use automatic mixed precision training')
78
+ parser.add_argument('--dpo',
79
+ action='store_true',
80
+ default=False,
81
+ help='Use Direct Preference Optimization')
82
+ parser.add_argument('--deepspeed.save_states',
83
+ dest='save_states',
84
+ default='model_only',
85
+ choices=['model_only', 'model+optimizer'],
86
+ help='save model/optimizer states')
87
+ parser.add_argument('--timeout',
88
+ default=60,
89
+ type=int,
90
+ help='timeout (in seconds) of cosyvoice_join.')
91
+ parser = deepspeed.add_config_arguments(parser)
92
+ args = parser.parse_args()
93
+ return args
94
+
95
+
96
+ @record
97
+ def main():
98
+ args = get_args()
99
+ logging.basicConfig(level=logging.DEBUG,
100
+ format='%(asctime)s %(levelname)s %(message)s')
101
+ # gan training has some special initialization logic
102
+ gan = True if args.model == 'hifigan' else False
103
+
104
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
105
+ if gan is True:
106
+ override_dict.pop('hift')
107
+ try:
108
+ with open(args.config, 'r') as f:
109
+ configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
110
+ except Exception:
111
+ with open(args.config, 'r') as f:
112
+ configs = load_hyperpyyaml(f, overrides=override_dict)
113
+ if gan is True:
114
+ configs['train_conf'] = configs['train_conf_gan']
115
+ configs['train_conf'].update(vars(args))
116
+
117
+ # Init env for ddp
118
+ init_distributed(args)
119
+
120
+ # Get dataset & dataloader
121
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
122
+ init_dataset_and_dataloader(args, configs, gan, args.dpo)
123
+
124
+ # Do some sanity checks and save config to args.model_dir
125
+ configs = check_modify_and_save_config(args, configs)
126
+
127
+ # Tensorboard summary
128
+ writer = init_summarywriter(args)
129
+
130
+ # load checkpoint
131
+ if args.dpo is True:
132
+ configs[args.model].forward = configs[args.model].forward_dpo
133
+ model = configs[args.model]
134
+ start_step, start_epoch = 0, -1
135
+ if args.checkpoint is not None:
136
+ if os.path.exists(args.checkpoint):
137
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
138
+ model.load_state_dict(state_dict, strict=False)
139
+ if 'step' in state_dict:
140
+ start_step = state_dict['step']
141
+ if 'epoch' in state_dict:
142
+ start_epoch = state_dict['epoch']
143
+ else:
144
+ logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
145
+
146
+ # Dispatch model from cpu to gpu
147
+ model = wrap_cuda_model(args, model)
148
+
149
+ # Get optimizer & scheduler
150
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
151
+ scheduler.set_step(start_step)
152
+ if scheduler_d is not None:
153
+ scheduler_d.set_step(start_step)
154
+
155
+ # Save init checkpoints
156
+ info_dict = deepcopy(configs['train_conf'])
157
+ info_dict['step'] = start_step
158
+ info_dict['epoch'] = start_epoch
159
+ save_model(model, 'init', info_dict)
160
+
161
+ # DPO related
162
+ if args.dpo is True:
163
+ ref_model = deepcopy(configs[args.model])
164
+ state_dict = torch.load(args.ref_model, map_location='cpu')
165
+ ref_model.load_state_dict(state_dict, strict=False)
166
+ dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
167
+ # NOTE wrapping ref_model in ddp is probably unnecessary because its parameters are never updated
168
+ ref_model = wrap_cuda_model(args, ref_model)
169
+ else:
170
+ ref_model, dpo_loss = None, None
171
+
172
+ # Get executor
173
+ executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
174
+ executor.step = start_step
175
+
176
+ # Init scaler, used for pytorch amp mixed precision training
177
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
178
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
179
+
180
+ # Start training loop
181
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
182
+ executor.epoch = epoch
183
+ train_dataset.set_epoch(epoch)
184
+ dist.barrier()
185
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
186
+ if gan is True:
187
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
188
+ writer, info_dict, scaler, group_join)
189
+ else:
190
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
191
+ dist.destroy_process_group(group_join)
192
+
193
+
194
+ if __name__ == '__main__':
195
+ main()
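
For context, train.py is meant to be launched through torch.distributed (torchrun) or deepspeed rather than invoked directly. Below is a minimal, hedged launch sketch; the config, data-list, and output paths are placeholders, and single-node/single-GPU settings are assumed purely for illustration.

```python
# Hypothetical launch sketch for cosyvoice/bin/train.py; all paths are placeholders.
import subprocess

cmd = [
    "torchrun", "--nnodes=1", "--nproc_per_node=1",   # assumed single node, single GPU
    "cosyvoice/bin/train.py",
    "--train_engine", "torch_ddp",
    "--model", "llm",                                 # which head of the yaml to train
    "--config", "conf/cosyvoice.yaml",                # placeholder config path
    "--train_data", "data/train.data.list",           # placeholder data lists
    "--cv_data", "data/dev.data.list",
    "--model_dir", "exp/llm",                         # checkpoints are written here
    "--tensorboard_dir", "tensorboard/llm",
]
subprocess.run(cmd, check=True)
```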
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from typing import Generator
17
+ from tqdm import tqdm
18
+ from hyperpyyaml import load_hyperpyyaml
19
+ from modelscope import snapshot_download
20
+ import torch
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.class_utils import get_model_type
25
+
26
+
27
+ class CosyVoice:
28
+
29
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
30
+ self.model_dir = model_dir
31
+ self.fp16 = fp16
32
+ if not os.path.exists(model_dir):
33
+ model_dir = snapshot_download(model_dir)
34
+ hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
35
+ if not os.path.exists(hyper_yaml_path):
36
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
37
+ with open(hyper_yaml_path, 'r') as f:
38
+ configs = load_hyperpyyaml(f)
39
+ assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
40
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
41
+ configs['feat_extractor'],
42
+ '{}/campplus.onnx'.format(model_dir),
43
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
44
+ '{}/spk2info.pt'.format(model_dir),
45
+ configs['allowed_special'])
46
+ self.sample_rate = configs['sample_rate']
47
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
48
+ load_jit, load_trt, fp16 = False, False, False
49
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
50
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
51
+ self.model.load('{}/llm.pt'.format(model_dir),
52
+ '{}/flow.pt'.format(model_dir),
53
+ '{}/hift.pt'.format(model_dir))
54
+ if load_jit:
55
+ self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
56
+ '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
57
+ '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
58
+ if load_trt:
59
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
60
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
61
+ trt_concurrent,
62
+ self.fp16)
63
+ del configs
64
+
65
+ def list_available_spks(self):
66
+ spks = list(self.frontend.spk2info.keys())
67
+ return spks
68
+
69
+ def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
70
+ assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
71
+ model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
72
+ del model_input['text']
73
+ del model_input['text_len']
74
+ self.frontend.spk2info[zero_shot_spk_id] = model_input
75
+ return True
76
+
77
+ def save_spkinfo(self):
78
+ torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
79
+
80
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
81
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
82
+ model_input = self.frontend.frontend_sft(i, spk_id)
83
+ start_time = time.time()
84
+ logging.info('synthesis text {}'.format(i))
85
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
86
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
87
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
88
+ yield model_output
89
+ start_time = time.time()
90
+
91
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
92
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
93
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
94
+ if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
95
+ logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
96
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
97
+ start_time = time.time()
98
+ logging.info('synthesis text {}'.format(i))
99
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
100
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
101
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
102
+ yield model_output
103
+ start_time = time.time()
104
+
105
+ def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
106
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
107
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
108
+ start_time = time.time()
109
+ logging.info('synthesis text {}'.format(i))
110
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
111
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
112
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
113
+ yield model_output
114
+ start_time = time.time()
115
+
116
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
117
+ assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
118
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
119
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
120
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
121
+ start_time = time.time()
122
+ logging.info('synthesis text {}'.format(i))
123
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
124
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
125
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
126
+ yield model_output
127
+ start_time = time.time()
128
+
129
+ def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
130
+ model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
131
+ start_time = time.time()
132
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
133
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
134
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
135
+ yield model_output
136
+ start_time = time.time()
137
+
138
+
139
+ class CosyVoice2(CosyVoice):
140
+
141
+ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
142
+ self.model_dir = model_dir
143
+ self.fp16 = fp16
144
+ if not os.path.exists(model_dir):
145
+ model_dir = snapshot_download(model_dir)
146
+ hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
147
+ if not os.path.exists(hyper_yaml_path):
148
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
149
+ with open(hyper_yaml_path, 'r') as f:
150
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
151
+ assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
152
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
153
+ configs['feat_extractor'],
154
+ '{}/campplus.onnx'.format(model_dir),
155
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
156
+ '{}/spk2info.pt'.format(model_dir),
157
+ configs['allowed_special'])
158
+ self.sample_rate = configs['sample_rate']
159
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
160
+ load_jit, load_trt, load_vllm, fp16 = False, False, False, False
161
+ logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
162
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
163
+ self.model.load('{}/llm.pt'.format(model_dir),
164
+ '{}/flow.pt'.format(model_dir),
165
+ '{}/hift.pt'.format(model_dir))
166
+ if load_vllm:
167
+ self.model.load_vllm('{}/vllm'.format(model_dir))
168
+ if load_jit:
169
+ self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
170
+ if load_trt:
171
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
172
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
173
+ trt_concurrent,
174
+ self.fp16)
175
+ del configs
176
+
177
+ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
178
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
179
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
180
+ start_time = time.time()
181
+ logging.info('synthesis text {}'.format(i))
182
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
183
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
184
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
185
+ yield model_output
186
+ start_time = time.time()
187
+
188
+
189
+ class CosyVoice3(CosyVoice2):
190
+
191
+ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
192
+ self.model_dir = model_dir
193
+ self.fp16 = fp16
194
+ if not os.path.exists(model_dir):
195
+ model_dir = snapshot_download(model_dir)
196
+ hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
197
+ if not os.path.exists(hyper_yaml_path):
198
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
199
+ with open(hyper_yaml_path, 'r') as f:
200
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
201
+ assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
202
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
203
+ configs['feat_extractor'],
204
+ '{}/campplus.onnx'.format(model_dir),
205
+ '{}/speech_tokenizer_v3.onnx'.format(model_dir),
206
+ '{}/spk2info.pt'.format(model_dir),
207
+ configs['allowed_special'])
208
+ self.sample_rate = configs['sample_rate']
209
+ if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
210
+ load_trt, fp16 = False, False
211
+ logging.warning('no cuda device, set load_trt/fp16 to False')
212
+ self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
213
+ self.model.load('{}/llm.pt'.format(model_dir),
214
+ '{}/flow.pt'.format(model_dir),
215
+ '{}/hift.pt'.format(model_dir))
216
+ if load_vllm:
217
+ self.model.load_vllm('{}/vllm'.format(model_dir))
218
+ if load_trt:
219
+ if self.fp16 is True:
220
+ logging.warning('DiT tensorRT fp16 engine has known performance issues, use with caution!')
221
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
222
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
223
+ trt_concurrent,
224
+ self.fp16)
225
+ del configs
226
+
227
+
228
+ def AutoModel(**kwargs):
229
+ if not os.path.exists(kwargs['model_dir']):
230
+ kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
231
+ if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
232
+ return CosyVoice(**kwargs)
233
+ elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
234
+ return CosyVoice2(**kwargs)
235
+ elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
236
+ return CosyVoice3(**kwargs)
237
+ else:
238
+ raise TypeError('No valid model type found!')
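
A minimal usage sketch of the API defined above, assuming a downloaded CosyVoice2-style model; the model id, prompt wav path, and text strings are placeholders rather than values taken from this repository.

```python
# Hypothetical usage sketch; model id, wav path and texts are placeholders.
import torchaudio
from cosyvoice.cli.cosyvoice import AutoModel

cosyvoice = AutoModel(model_dir='iic/CosyVoice2-0.5B')     # AutoModel picks the class from the yaml found in model_dir
for idx, out in enumerate(cosyvoice.inference_zero_shot(
        'Text to synthesize goes here.',                   # tts_text
        'Transcript of the prompt audio.',                 # prompt_text
        './zero_shot_prompt.wav',                          # placeholder prompt wav path
        stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(idx), out['tts_speech'], cosyvoice.sample_rate)
```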
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ from typing import Generator
16
+ import json
17
+ import onnxruntime
18
+ import torch
19
+ import numpy as np
20
+ import whisper
21
+ from typing import Callable
22
+ import torchaudio.compliance.kaldi as kaldi
23
+ import torchaudio
24
+ import os
25
+ import re
26
+ import inflect
27
+ try:
28
+ import ttsfrd
29
+ use_ttsfrd = True
30
+ except ImportError:
31
+ print("failed to import ttsfrd, falling back to wetext")
32
+ from wetext import Normalizer as ZhNormalizer
33
+ from wetext import Normalizer as EnNormalizer
34
+ use_ttsfrd = False
35
+ from cosyvoice.utils.file_utils import logging, load_wav
36
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
37
+
38
+
39
+ class CosyVoiceFrontEnd:
40
+
41
+ def __init__(self,
42
+ get_tokenizer: Callable,
43
+ feat_extractor: Callable,
44
+ campplus_model: str,
45
+ speech_tokenizer_model: str,
46
+ spk2info: str = '',
47
+ allowed_special: str = 'all'):
48
+ self.tokenizer = get_tokenizer()
49
+ self.feat_extractor = feat_extractor
50
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ option = onnxruntime.SessionOptions()
52
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
53
+ option.intra_op_num_threads = 1
54
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
55
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
56
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
57
+ "CPUExecutionProvider"])
58
+ if os.path.exists(spk2info):
59
+ self.spk2info = torch.load(spk2info, map_location=self.device)
60
+ else:
61
+ self.spk2info = {}
62
+ self.allowed_special = allowed_special
63
+ self.use_ttsfrd = use_ttsfrd
64
+ if self.use_ttsfrd:
65
+ self.frd = ttsfrd.TtsFrontendEngine()
66
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
67
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
68
+ 'failed to initialize ttsfrd resource'
69
+ self.frd.set_lang_type('pinyinvg')
70
+ else:
71
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False)
72
+ self.en_tn_model = EnNormalizer()
73
+ self.inflect_parser = inflect.engine()
74
+
75
+ def _extract_text_token(self, text):
76
+ if isinstance(text, Generator):
77
+ logging.info('get tts_text generator, will return _extract_text_token_generator!')
78
+ # NOTE add a dummy text_token_len for compatibility
79
+ return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
80
+ else:
81
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
82
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
83
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
84
+ return text_token, text_token_len
85
+
86
+ def _extract_text_token_generator(self, text_generator):
87
+ for text in text_generator:
88
+ text_token, _ = self._extract_text_token(text)
89
+ for i in range(text_token.shape[1]):
90
+ yield text_token[:, i: i + 1]
91
+
92
+ def _extract_speech_token(self, prompt_wav):
93
+ speech = load_wav(prompt_wav, 16000)
94
+ assert speech.shape[1] / 16000 <= 30, 'extracting speech tokens is not supported for audio longer than 30s'
95
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
96
+ speech_token = self.speech_tokenizer_session.run(None,
97
+ {self.speech_tokenizer_session.get_inputs()[0].name:
98
+ feat.detach().cpu().numpy(),
99
+ self.speech_tokenizer_session.get_inputs()[1].name:
100
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
101
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
102
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
103
+ return speech_token, speech_token_len
104
+
105
+ def _extract_spk_embedding(self, prompt_wav):
106
+ speech = load_wav(prompt_wav, 16000)
107
+ feat = kaldi.fbank(speech,
108
+ num_mel_bins=80,
109
+ dither=0,
110
+ sample_frequency=16000)
111
+ feat = feat - feat.mean(dim=0, keepdim=True)
112
+ embedding = self.campplus_session.run(None,
113
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
114
+ embedding = torch.tensor([embedding]).to(self.device)
115
+ return embedding
116
+
117
+ def _extract_speech_feat(self, prompt_wav):
118
+ speech = load_wav(prompt_wav, 24000)
119
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
120
+ speech_feat = speech_feat.unsqueeze(dim=0)
121
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
122
+ return speech_feat, speech_feat_len
123
+
124
+ def text_normalize(self, text, split=True, text_frontend=True):
125
+ if isinstance(text, Generator):
126
+ logging.info('get tts_text generator, will skip text_normalize!')
127
+ return [text]
128
+ # NOTE skip text_frontend when an ssml symbol appears in the text
129
+ if '<|' in text and '|>' in text:
130
+ text_frontend = False
131
+ if text_frontend is False or text == '':
132
+ return [text] if split is True else text
133
+ text = text.strip()
134
+ if self.use_ttsfrd:
135
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
136
+ text = ''.join(texts)
137
+ else:
138
+ if contains_chinese(text):
139
+ text = self.zh_tn_model.normalize(text)
140
+ text = text.replace("\n", "")
141
+ text = replace_blank(text)
142
+ text = replace_corner_mark(text)
143
+ text = text.replace(".", "。")
144
+ text = text.replace(" - ", ",")
145
+ text = remove_bracket(text)
146
+ text = re.sub(r'[,,、]+$', '。', text)
147
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
148
+ token_min_n=60, merge_len=20, comma_split=False))
149
+ else:
150
+ text = self.en_tn_model.normalize(text)
151
+ text = spell_out_number(text, self.inflect_parser)
152
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
153
+ token_min_n=60, merge_len=20, comma_split=False))
154
+ texts = [i for i in texts if not is_only_punctuation(i)]
155
+ return texts if split is True else text
156
+
157
+ def frontend_sft(self, tts_text, spk_id):
158
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
159
+ embedding = self.spk2info[spk_id]['embedding']
160
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
161
+ return model_input
162
+
163
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
164
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
165
+ if zero_shot_spk_id == '':
166
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
167
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
168
+ speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
169
+ if resample_rate == 24000:
170
+ # cosyvoice2, force speech_feat length to be exactly 2x speech_token length
171
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
172
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
173
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
174
+ embedding = self._extract_spk_embedding(prompt_wav)
175
+ model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
176
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
177
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
178
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
179
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
180
+ else:
181
+ model_input = self.spk2info[zero_shot_spk_id]
182
+ model_input['text'] = tts_text_token
183
+ model_input['text_len'] = tts_text_token_len
184
+ return model_input
185
+
186
+ def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
187
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
188
+ # in cross lingual mode, we remove prompt in llm
189
+ del model_input['prompt_text']
190
+ del model_input['prompt_text_len']
191
+ del model_input['llm_prompt_speech_token']
192
+ del model_input['llm_prompt_speech_token_len']
193
+ return model_input
194
+
195
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
196
+ model_input = self.frontend_sft(tts_text, spk_id)
197
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
198
+ del model_input['llm_embedding']
199
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
200
+ model_input['prompt_text'] = instruct_text_token
201
+ model_input['prompt_text_len'] = instruct_text_token_len
202
+ return model_input
203
+
204
+ def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
205
+ model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
206
+ del model_input['llm_prompt_speech_token']
207
+ del model_input['llm_prompt_speech_token_len']
208
+ return model_input
209
+
210
+ def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
211
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
212
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
213
+ embedding = self._extract_spk_embedding(prompt_wav)
214
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
215
+ model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
216
+ 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
217
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
218
+ 'flow_embedding': embedding}
219
+ return model_input
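
For orientation, a hedged sketch of driving the frontend directly, assuming a `frontend` instance constructed as in cosyvoice.py and a 24 kHz feature pipeline; the prompt wav path and texts are placeholders.

```python
# Hypothetical sketch: build and inspect a zero-shot model_input (paths/texts are placeholders).
texts = frontend.text_normalize('A long paragraph that will be split into sentence chunks.', split=True)
model_input = frontend.frontend_zero_shot(texts[0],
                                          'Transcript of the prompt audio.',
                                          './prompt.wav',          # placeholder prompt wav
                                          24000,                   # resample_rate of the feat pipeline
                                          '')                      # no cached zero-shot speaker
print(sorted(model_input.keys()))
# expected keys: text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token(_len),
# flow_prompt_speech_token(_len), prompt_speech_feat(_len), llm_embedding, flow_embedding
```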
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,430 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from typing import Generator
17
+ import torch
18
+ import numpy as np
19
+ import threading
20
+ import time
21
+ from torch.nn import functional as F
22
+ from contextlib import nullcontext
23
+ import uuid
24
+ from cosyvoice.utils.common import fade_in_out
25
+ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
26
+ from cosyvoice.utils.common import TrtContextWrapper
27
+
28
+
29
+ class CosyVoiceModel:
30
+
31
+ def __init__(self,
32
+ llm: torch.nn.Module,
33
+ flow: torch.nn.Module,
34
+ hift: torch.nn.Module,
35
+ fp16: bool = False):
36
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37
+ self.llm = llm
38
+ self.flow = flow
39
+ self.hift = hift
40
+ self.fp16 = fp16
41
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
42
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
43
+ self.token_overlap_len = 20
44
+ # mel fade in out
45
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
46
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
47
+ # hift cache
48
+ self.mel_cache_len = 20
49
+ self.source_cache_len = int(self.mel_cache_len * 256)
50
+ # speech fade in out
51
+ self.speech_window = np.hamming(2 * self.source_cache_len)
52
+ # rtf and decoding related
53
+ self.stream_scale_factor = 1
54
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
55
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
56
+ self.lock = threading.Lock()
57
+ # dict used to store session related variable
58
+ self.tts_speech_token_dict = {}
59
+ self.llm_end_dict = {}
60
+ self.mel_overlap_dict = {}
61
+ self.flow_cache_dict = {}
62
+ self.hift_cache_dict = {}
63
+
64
+ def load(self, llm_model, flow_model, hift_model):
65
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
66
+ self.llm.to(self.device).eval()
67
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
68
+ self.flow.to(self.device).eval()
69
+ # in case hift_model is a hifigan model
70
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
71
+ self.hift.load_state_dict(hift_state_dict, strict=True)
72
+ self.hift.to(self.device).eval()
73
+
74
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
75
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
76
+ self.llm.text_encoder = llm_text_encoder
77
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
78
+ self.llm.llm = llm_llm
79
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
80
+ self.flow.encoder = flow_encoder
81
+
82
+ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
83
+ assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
84
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
85
+ convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
86
+ del self.flow.decoder.estimator
87
+ import tensorrt as trt
88
+ with open(flow_decoder_estimator_model, 'rb') as f:
89
+ estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
90
+ assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
91
+ self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
92
+
93
+ def get_trt_kwargs(self):
94
+ min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
95
+ opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
96
+ max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
97
+ input_names = ["x", "mask", "mu", "cond"]
98
+ return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
99
+
100
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
101
+ with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
102
+ if isinstance(text, Generator):
103
+ assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and does not support vllm!'
104
+ for i in self.llm.inference_bistream(text=text,
105
+ prompt_text=prompt_text.to(self.device),
106
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
107
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
108
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
109
+ embedding=llm_embedding.to(self.device)):
110
+ self.tts_speech_token_dict[uuid].append(i)
111
+ else:
112
+ for i in self.llm.inference(text=text.to(self.device),
113
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
114
+ prompt_text=prompt_text.to(self.device),
115
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
116
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
117
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
118
+ embedding=llm_embedding.to(self.device),
119
+ uuid=uuid):
120
+ self.tts_speech_token_dict[uuid].append(i)
121
+ self.llm_end_dict[uuid] = True
122
+
123
+ def vc_job(self, source_speech_token, uuid):
124
+ self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
125
+ self.llm_end_dict[uuid] = True
126
+
127
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
128
+ with torch.cuda.amp.autocast(self.fp16):
129
+ tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
130
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
131
+ prompt_token=prompt_token.to(self.device),
132
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
133
+ prompt_feat=prompt_feat.to(self.device),
134
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
135
+ embedding=embedding.to(self.device),
136
+ flow_cache=self.flow_cache_dict[uuid])
137
+
138
+ # mel overlap fade in out
139
+ if self.mel_overlap_dict[uuid].shape[2] != 0:
140
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
141
+ # append hift cache
142
+ if self.hift_cache_dict[uuid] is not None:
143
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
144
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
145
+ else:
146
+ hift_cache_source = torch.zeros(1, 1, 0)
147
+ # keep overlap mel and hift cache
148
+ if finalize is False:
149
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
150
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
151
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
152
+ if self.hift_cache_dict[uuid] is not None:
153
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
154
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
155
+ 'source': tts_source[:, :, -self.source_cache_len:],
156
+ 'speech': tts_speech[:, -self.source_cache_len:]}
157
+ tts_speech = tts_speech[:, :-self.source_cache_len]
158
+ else:
159
+ if speed != 1.0:
160
+ assert self.hift_cache_dict[uuid] is None, 'speed change is only supported in non-stream inference mode'
161
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
162
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
163
+ if self.hift_cache_dict[uuid] is not None:
164
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
165
+ return tts_speech
166
+
167
+ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
168
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
169
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
170
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
171
+ prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
172
+ # this_uuid is used to track variables related to this inference thread
173
+ this_uuid = str(uuid.uuid1())
174
+ with self.lock:
175
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
176
+ self.hift_cache_dict[this_uuid] = None
177
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
178
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
179
+ if source_speech_token.shape[1] == 0:
180
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
181
+ else:
182
+ p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
183
+ p.start()
184
+ if stream is True:
185
+ token_hop_len = self.token_min_hop_len
186
+ while True:
187
+ time.sleep(0.1)
188
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
189
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
190
+ .unsqueeze(dim=0)
191
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
192
+ prompt_token=flow_prompt_speech_token,
193
+ prompt_feat=prompt_speech_feat,
194
+ embedding=flow_embedding,
195
+ uuid=this_uuid,
196
+ finalize=False)
197
+ yield {'tts_speech': this_tts_speech.cpu()}
198
+ with self.lock:
199
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
200
+ # increase token_hop_len for better speech quality
201
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
202
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
203
+ break
204
+ p.join()
205
+ # deal with remaining tokens, make sure the remaining token length equals token_hop_len when cache_speech is not None
206
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
207
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
208
+ prompt_token=flow_prompt_speech_token,
209
+ prompt_feat=prompt_speech_feat,
210
+ embedding=flow_embedding,
211
+ uuid=this_uuid,
212
+ finalize=True)
213
+ yield {'tts_speech': this_tts_speech.cpu()}
214
+ else:
215
+ # deal with all tokens
216
+ p.join()
217
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
218
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
219
+ prompt_token=flow_prompt_speech_token,
220
+ prompt_feat=prompt_speech_feat,
221
+ embedding=flow_embedding,
222
+ uuid=this_uuid,
223
+ finalize=True,
224
+ speed=speed)
225
+ yield {'tts_speech': this_tts_speech.cpu()}
226
+ with self.lock:
227
+ self.tts_speech_token_dict.pop(this_uuid)
228
+ self.llm_end_dict.pop(this_uuid)
229
+ self.mel_overlap_dict.pop(this_uuid)
230
+ self.hift_cache_dict.pop(this_uuid)
231
+ self.flow_cache_dict.pop(this_uuid)
232
+ if torch.cuda.is_available():
233
+ torch.cuda.empty_cache()
234
+ torch.cuda.current_stream().synchronize()
235
+
236
+
237
+ class CosyVoice2Model(CosyVoiceModel):
238
+
239
+ def __init__(self,
240
+ llm: torch.nn.Module,
241
+ flow: torch.nn.Module,
242
+ hift: torch.nn.Module,
243
+ fp16: bool = False):
244
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
245
+ self.llm = llm
246
+ self.flow = flow
247
+ self.hift = hift
248
+ self.fp16 = fp16
249
+ # NOTE must match the training static_chunk_size
250
+ self.token_hop_len = 25
251
+ # hift cache
252
+ self.mel_cache_len = 8
253
+ self.source_cache_len = int(self.mel_cache_len * 480)
254
+ # speech fade in out
255
+ self.speech_window = np.hamming(2 * self.source_cache_len)
256
+ # rtf and decoding related
257
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
258
+ self.lock = threading.Lock()
259
+ # dict used to store session related variable
260
+ self.tts_speech_token_dict = {}
261
+ self.llm_end_dict = {}
262
+ self.hift_cache_dict = {}
263
+
264
+ def load_jit(self, flow_encoder_model):
265
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
266
+ self.flow.encoder = flow_encoder
267
+
268
+ def load_vllm(self, model_dir):
269
+ export_cosyvoice2_vllm(self.llm, model_dir, self.device)
270
+ from vllm import EngineArgs, LLMEngine
271
+ engine_args = EngineArgs(model=model_dir,
272
+ skip_tokenizer_init=True,
273
+ enable_prompt_embeds=True,
274
+ gpu_memory_utilization=0.2)
275
+ self.llm.vllm = LLMEngine.from_engine_args(engine_args)
276
+ self.llm.lock = threading.Lock()
277
+ del self.llm.llm.model.model.layers
278
+
279
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
280
+ with torch.cuda.amp.autocast(self.fp16):
281
+ tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
282
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
283
+ prompt_token=prompt_token.to(self.device),
284
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
285
+ prompt_feat=prompt_feat.to(self.device),
286
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
287
+ embedding=embedding.to(self.device),
288
+ streaming=stream,
289
+ finalize=finalize)
290
+ tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
291
+ # append hift cache
292
+ if self.hift_cache_dict[uuid] is not None:
293
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
294
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
295
+ else:
296
+ hift_cache_source = torch.zeros(1, 1, 0)
297
+ # keep overlap mel and hift cache
298
+ if finalize is False:
299
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
300
+ if self.hift_cache_dict[uuid] is not None:
301
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
302
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
303
+ 'source': tts_source[:, :, -self.source_cache_len:],
304
+ 'speech': tts_speech[:, -self.source_cache_len:]}
305
+ tts_speech = tts_speech[:, :-self.source_cache_len]
306
+ else:
307
+ if speed != 1.0:
308
+ assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
309
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
310
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
311
+ if self.hift_cache_dict[uuid] is not None:
312
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
313
+ return tts_speech
314
+
315
+ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
316
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
317
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
318
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
319
+ prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
320
+ # this_uuid is used to track variables related to this inference thread
321
+ this_uuid = str(uuid.uuid1())
322
+ with self.lock:
323
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
324
+ self.hift_cache_dict[this_uuid] = None
325
+ if source_speech_token.shape[1] == 0:
326
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
327
+ else:
328
+ p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
329
+ p.start()
330
+ if stream is True:
331
+ token_offset = 0
332
+ prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
333
+ while True:
334
+ time.sleep(0.1)
335
+ this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
336
+ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
337
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
338
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
339
+ prompt_token=flow_prompt_speech_token,
340
+ prompt_feat=prompt_speech_feat,
341
+ embedding=flow_embedding,
342
+ token_offset=token_offset,
343
+ uuid=this_uuid,
344
+ stream=stream,
345
+ finalize=False)
346
+ token_offset += this_token_hop_len
347
+ yield {'tts_speech': this_tts_speech.cpu()}
348
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
349
+ break
350
+ p.join()
351
+ # deal with remaining tokens, make sure the remaining token length equals token_hop_len when cache_speech is not None
352
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
353
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
354
+ prompt_token=flow_prompt_speech_token,
355
+ prompt_feat=prompt_speech_feat,
356
+ embedding=flow_embedding,
357
+ token_offset=token_offset,
358
+ uuid=this_uuid,
359
+ finalize=True)
360
+ yield {'tts_speech': this_tts_speech.cpu()}
361
+ else:
362
+ # deal with all tokens
363
+ p.join()
364
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
365
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
366
+ prompt_token=flow_prompt_speech_token,
367
+ prompt_feat=prompt_speech_feat,
368
+ embedding=flow_embedding,
369
+ token_offset=0,
370
+ uuid=this_uuid,
371
+ finalize=True,
372
+ speed=speed)
373
+ yield {'tts_speech': this_tts_speech.cpu()}
374
+ with self.lock:
375
+ self.tts_speech_token_dict.pop(this_uuid)
376
+ self.llm_end_dict.pop(this_uuid)
377
+ self.hift_cache_dict.pop(this_uuid)
378
+ if torch.cuda.is_available():
379
+ torch.cuda.empty_cache()
380
+ torch.cuda.current_stream().synchronize()
381
+
382
+
383
+ class CosyVoice3Model(CosyVoice2Model):
384
+
385
+ def __init__(self,
386
+ llm: torch.nn.Module,
387
+ flow: torch.nn.Module,
388
+ hift: torch.nn.Module,
389
+ fp16: bool = False):
390
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
391
+ self.llm = llm
392
+ self.flow = flow
393
+ self.hift = hift
394
+ self.fp16 = fp16
395
+ # NOTE must matching training static_chunk_size
396
+ self.token_hop_len = 25
397
+ # rtf and decoding related
398
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
399
+ self.lock = threading.Lock()
400
+ # dict used to store session related variable
401
+ self.tts_speech_token_dict = {}
402
+ self.llm_end_dict = {}
403
+ self.hift_cache_dict = {}
404
+
405
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
406
+ with torch.cuda.amp.autocast(self.fp16):
407
+ tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
408
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
409
+ prompt_token=prompt_token.to(self.device),
410
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
411
+ prompt_feat=prompt_feat.to(self.device),
412
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
413
+ embedding=embedding.to(self.device),
414
+ streaming=stream,
415
+ finalize=finalize)
416
+ tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
417
+ # append mel cache
418
+ if self.hift_cache_dict[uuid] is not None:
419
+ hift_cache_mel = self.hift_cache_dict[uuid]['mel']
420
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
421
+ self.hift_cache_dict[uuid]['mel'] = tts_mel
422
+ else:
423
+ self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
424
+ if speed != 1.0:
425
+ assert token_offset == 0 and finalize is True, 'speed change is only supported in non-stream inference mode'
426
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
427
+ tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
428
+ tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
429
+ self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
430
+ return tts_speech
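
Since tts() is a generator that yields audio chunk dicts, a streaming caller typically concatenates the chunks; a minimal sketch, assuming a loaded model and a model_input dict prepared by the frontend:

```python
# Hypothetical streaming consumption sketch; `model` and `model_input` are prepared elsewhere.
import torch

chunks = []
for out in model.tts(**model_input, stream=True):
    chunks.append(out['tts_speech'])       # each chunk is a [1, num_samples] CPU tensor
speech = torch.cat(chunks, dim=1)          # full utterance, ready to save or play back
```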
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import math
18
+ from functools import partial
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ from torch.utils.data import IterableDataset
23
+ from cosyvoice.utils.file_utils import read_lists
24
+
25
+
26
+ class Processor(IterableDataset):
27
+
28
+ def __init__(self, source, f, *args, **kw):
29
+ assert callable(f)
30
+ self.source = source
31
+ self.f = f
32
+ self.args = args
33
+ self.kw = kw
34
+
35
+ def set_epoch(self, epoch):
36
+ self.source.set_epoch(epoch)
37
+
38
+ def __iter__(self):
39
+ """ Return an iterator over the source dataset processed by the
40
+ given processor.
41
+ """
42
+ assert self.source is not None
43
+ assert callable(self.f)
44
+ return self.f(iter(self.source), *self.args, **self.kw)
45
+
46
+ def apply(self, f):
47
+ assert callable(f)
48
+ return Processor(self, f, *self.args, **self.kw)
49
+
50
+
51
+ class DistributedSampler:
52
+
53
+ def __init__(self, shuffle=True, partition=True):
54
+ self.epoch = -1
55
+ self.update()
56
+ self.shuffle = shuffle
57
+ self.partition = partition
58
+
59
+ def update(self):
60
+ assert dist.is_available()
61
+ if dist.is_initialized():
62
+ self.rank = dist.get_rank()
63
+ self.world_size = dist.get_world_size()
64
+ else:
65
+ self.rank = 0
66
+ self.world_size = 1
67
+ worker_info = torch.utils.data.get_worker_info()
68
+ if worker_info is None:
69
+ self.worker_id = 0
70
+ self.num_workers = 1
71
+ else:
72
+ self.worker_id = worker_info.id
73
+ self.num_workers = worker_info.num_workers
74
+ return dict(rank=self.rank,
75
+ world_size=self.world_size,
76
+ worker_id=self.worker_id,
77
+ num_workers=self.num_workers)
78
+
79
+ def set_epoch(self, epoch):
80
+ self.epoch = epoch
81
+
82
+ def sample(self, data):
83
+ """ Sample data according to rank/world_size/num_workers
84
+
85
+ Args:
86
+ data(List): input data list
87
+
88
+ Returns:
89
+ List: data list after sample
90
+ """
91
+ data = list(range(len(data)))
92
+ # pad the data list so it partitions evenly across ranks and workers
93
+ if self.partition:
94
+ if self.shuffle:
95
+ random.Random(self.epoch).shuffle(data)
96
+ if len(data) < self.world_size:
97
+ data = data * math.ceil(self.world_size / len(data))
98
+ data = data[:self.world_size]
99
+ data = data[self.rank::self.world_size]
100
+ if len(data) < self.num_workers:
101
+ data = data * math.ceil(self.num_workers / len(data))
102
+ data = data[:self.num_workers]
103
+ data = data[self.worker_id::self.num_workers]
104
+ return data
105
+
106
+
107
+ class DataList(IterableDataset):
108
+
109
+ def __init__(self, lists, shuffle=True, partition=True):
110
+ self.lists = lists
111
+ self.sampler = DistributedSampler(shuffle, partition)
112
+
113
+ def set_epoch(self, epoch):
114
+ self.sampler.set_epoch(epoch)
115
+
116
+ def __iter__(self):
117
+ sampler_info = self.sampler.update()
118
+ indexes = self.sampler.sample(self.lists)
119
+ for index in indexes:
120
+ data = dict(src=self.lists[index])
121
+ data.update(sampler_info)
122
+ yield data
123
+
124
+
125
+ def Dataset(data_list_file,
126
+ data_pipeline,
127
+ mode='train',
128
+ gan=False,
129
+ dpo=False,
130
+ shuffle=True,
131
+ partition=True):
132
+ """ Construct dataset from arguments
133
+
134
+ We have two shuffle stages in the Dataset. The first is a global
135
+ shuffle at the shard/raw file level. The second is a global shuffle
136
+ at the training-sample level.
137
+
138
+ Args:
139
+ data_list_file(str): file listing the parquet/raw data shards
140
+ data_pipeline(List[Callable]): processor functions applied in order
141
+ partition(bool): whether to do data partition in terms of rank
142
+ """
143
+ lists = read_lists(data_list_file)
144
+ dataset = DataList(lists,
145
+ shuffle=shuffle,
146
+ partition=partition)
147
+ # map partial arg to padding func
148
+ data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
149
+ for func in data_pipeline:
150
+ dataset = Processor(dataset, func, mode=mode)
151
+ return dataset
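A minimal usage sketch of the pieces above, for orientation only: it assumes a data.list file whose lines point at parquet shards and it borrows the processor functions added in cosyvoice/dataset/processor.py below. The tokenizer factory and feature extractor are hypothetical placeholders, and the exact pipeline order in the released recipes may differ.

# Hypothetical pipeline sketch; not part of the upload. Iterating it requires real shards.
from functools import partial
from torch.utils.data import DataLoader

from cosyvoice.dataset.dataset import Dataset
from cosyvoice.dataset import processor


def my_get_tokenizer():                        # hypothetical stand-in for the tokenizer factory
    raise NotImplementedError('plug in the CosyVoice tokenizer factory here')


def my_feat_extractor(waveform):               # hypothetical stand-in for the mel extractor
    raise NotImplementedError('plug in the mel-spectrogram extractor here')


data_pipeline = [
    processor.parquet_opener,                                            # open parquet shards from data.list
    partial(processor.tokenize, get_tokenizer=my_get_tokenizer,
            allowed_special='all'),
    processor.filter,                                                    # drop too short/long utterances
    partial(processor.resample, resample_rate=24000),                    # target rate is an assumption
    partial(processor.compute_fbank, feat_extractor=my_feat_extractor),
    partial(processor.parse_embedding, normalize=True),
    processor.shuffle,
    processor.sort,
    processor.batch,                                                     # static batching by default
    partial(processor.padding, use_spk_embedding=False),                 # Dataset() re-binds gan/dpo here
]

train_dataset = Dataset('data.list', data_pipeline=data_pipeline, mode='train',
                        gan=False, dpo=False, shuffle=True, partition=True)
train_dataset.set_epoch(0)
# batching and padding already happen inside the pipeline, so batch_size=None here
train_loader = DataLoader(train_dataset, batch_size=None, num_workers=2)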
cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,443 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+ import pyworld as pw
24
+
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
+ def parquet_opener(data, mode='train', tts_data={}):
30
+ """ Give url or local file, return file descriptor
31
+ Inplace operation.
32
+
33
+ Args:
34
+ data(Iterable[str]): url or local file list
35
+
36
+ Returns:
37
+ Iterable[{src, ...parquet row fields}]
38
+ """
39
+ for sample in data:
40
+ assert 'src' in sample
41
+ url = sample['src']
42
+ try:
43
+ for df in pq.ParquetFile(url).iter_batches(batch_size=64):
44
+ df = df.to_pandas()
45
+ for i in range(len(df)):
46
+ sample.update(dict(df.loc[i]))
47
+ if mode == 'train':
48
+ # NOTE do not return sample directly, must initialize a new dict
49
+ yield {**sample}
50
+ else:
51
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
52
+ yield {**sample, 'tts_index': index, 'tts_text': text}
53
+ except Exception as ex:
54
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
55
+
56
+
57
+ def filter(data,
58
+ max_length=10240,
59
+ min_length=10,
60
+ token_max_length=200,
61
+ token_min_length=1,
62
+ min_output_input_ratio=0.0005,
63
+ max_output_input_ratio=1,
64
+ mode='train'):
65
+ """ Filter sample according to feature and label length
66
+ Inplace operation.
67
+
68
+ Args:
69
+ data: Iterable[{key, wav, label, sample_rate}]
70
+ max_length: drop utterance which is greater than max_length(10ms)
71
+ min_length: drop utterance which is less than min_length(10ms)
72
+ token_max_length: drop utterance which is greater than
73
+ token_max_length, especially when using char units for
74
+ English modeling
75
+ token_min_length: drop utterance which is
76
+ less than token_min_length
77
+ min_output_input_ratio: minimal ratio of
78
+ token_length / feats_length(10ms)
79
+ max_output_input_ratio: maximum ratio of
80
+ token_length / feats_length(10ms)
81
+
82
+ Returns:
83
+ Iterable[{key, wav, label, sample_rate}]
84
+ """
85
+ for sample in data:
86
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
87
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
88
+ del sample['audio_data']
89
+ # sample['speech'] is a torch.Tensor; we count 100 frames per second
90
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
91
+ if num_frames < min_length:
92
+ continue
93
+ if num_frames > max_length:
94
+ continue
95
+ if len(sample['text_token']) < token_min_length:
96
+ continue
97
+ if len(sample['text_token']) > token_max_length:
98
+ continue
99
+ if len(sample['speech_token']) == 0:
100
+ continue
101
+ if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
102
+ continue
103
+ if num_frames != 0:
104
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
105
+ continue
106
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
107
+ continue
108
+ yield sample
109
+
110
+
111
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
112
+ """ Resample data.
113
+ Inplace operation.
114
+
115
+ Args:
116
+ data: Iterable[{key, wav, label, sample_rate}]
117
+ resample_rate: target resample rate
118
+
119
+ Returns:
120
+ Iterable[{key, wav, label, sample_rate}]
121
+ """
122
+ for sample in data:
123
+ assert 'sample_rate' in sample
124
+ assert 'speech' in sample
125
+ sample_rate = sample['sample_rate']
126
+ waveform = sample['speech']
127
+ if sample_rate != resample_rate:
128
+ if sample_rate < min_sample_rate:
129
+ continue
130
+ sample['sample_rate'] = resample_rate
131
+ sample['speech'] = torchaudio.transforms.Resample(
132
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
133
+ max_val = sample['speech'].abs().max()
134
+ if max_val > 1:
135
+ sample['speech'] /= max_val
136
+ yield sample
137
+
138
+
139
+ def truncate(data, truncate_length=24576, mode='train'):
140
+ """ Truncate data.
141
+
142
+ Args:
143
+ data: Iterable[{key, wav, label, sample_rate}]
144
+ truncate_length: truncate length
145
+
146
+ Returns:
147
+ Iterable[{key, wav, label, sample_rate}]
148
+ """
149
+ for sample in data:
150
+ waveform = sample['speech']
151
+ if waveform.shape[1] > truncate_length:
152
+ start = random.randint(0, waveform.shape[1] - truncate_length)
153
+ waveform = waveform[:, start: start + truncate_length]
154
+ else:
155
+ waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
156
+ sample['speech'] = waveform
157
+ yield sample
158
+
159
+
160
+ def compute_fbank(data,
161
+ feat_extractor,
162
+ token_mel_ratio=0,
163
+ mode='train'):
164
+ """ Extract fbank
165
+
166
+ Args:
167
+ data: Iterable[{key, wav, label, sample_rate}]
168
+
169
+ Returns:
170
+ Iterable[{key, feat, label}]
171
+ """
172
+ for sample in data:
173
+ assert 'sample_rate' in sample
174
+ assert 'speech' in sample
175
+ assert 'utt' in sample
176
+ assert 'text_token' in sample
177
+ waveform = sample['speech']
178
+ feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
179
+ if token_mel_ratio != 0:
180
+ # trim to align speech_token and speech_feat
181
+ token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
182
+ feat = feat[:token_mel_ratio * token_len]
183
+ sample["speech_token"] = sample["speech_token"][:token_len]
184
+ sample['speech_feat'] = feat
185
+ yield sample
186
+
187
+
188
+ def compute_f0(data, sample_rate, hop_size, mode='train'):
189
+ """ Extract f0
190
+
191
+ Args:
192
+ data: Iterable[{key, wav, label, sample_rate}]
193
+
194
+ Returns:
195
+ Iterable[{key, feat, label}]
196
+ """
197
+ frame_period = hop_size * 1000 / sample_rate
198
+ for sample in data:
199
+ assert 'sample_rate' in sample
200
+ assert 'speech' in sample
201
+ assert 'utt' in sample
202
+ assert 'text_token' in sample
203
+ waveform = sample['speech']
204
+ _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
205
+ if sum(_f0 != 0) < 5: # this happens when the algorithm fails
206
+ _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
207
+ f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
208
+ f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
209
+ sample['pitch_feat'] = f0
210
+ yield sample
211
+
212
+
213
+ def parse_embedding(data, normalize, mode='train'):
214
+ """ Parse utt_embedding/spk_embedding
215
+
216
+ Args:
217
+ data: Iterable[{key, wav, label, sample_rate}]
218
+
219
+ Returns:
220
+ Iterable[{key, feat, label}]
221
+ """
222
+ for sample in data:
223
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
224
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
225
+ if normalize:
226
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
227
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
228
+ yield sample
229
+
230
+
231
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
232
+ """ Decode text to chars or BPE
233
+ Inplace operation
234
+
235
+ Args:
236
+ data: Iterable[{key, wav, txt, sample_rate}]
237
+
238
+ Returns:
239
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
240
+ """
241
+ tokenizer = get_tokenizer()
242
+ for sample in data:
243
+ assert 'text' in sample
244
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
245
+ if 'instruct' in sample:
246
+ sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
247
+ else:
248
+ sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
249
+ yield sample
250
+
251
+
252
+ def shuffle(data, shuffle_size=10000, mode='train'):
253
+ """ Local shuffle the data
254
+
255
+ Args:
256
+ data: Iterable[{key, feat, label}]
257
+ shuffle_size: buffer size for shuffle
258
+
259
+ Returns:
260
+ Iterable[{key, feat, label}]
261
+ """
262
+ buf = []
263
+ for sample in data:
264
+ buf.append(sample)
265
+ if len(buf) >= shuffle_size:
266
+ random.shuffle(buf)
267
+ for x in buf:
268
+ yield x
269
+ buf = []
270
+ # The sample left over
271
+ random.shuffle(buf)
272
+ for x in buf:
273
+ yield x
274
+
275
+
276
+ def sort(data, sort_size=500, mode='train'):
277
+ """ Sort the data by feature length.
278
+ Sort is used after shuffle and before batch, so we can group
279
+ utts with similar lengths into a batch, and `sort_size` should
280
+ be less than `shuffle_size`
281
+
282
+ Args:
283
+ data: Iterable[{key, feat, label}]
284
+ sort_size: buffer size for sort
285
+
286
+ Returns:
287
+ Iterable[{key, feat, label}]
288
+ """
289
+
290
+ buf = []
291
+ for sample in data:
292
+ buf.append(sample)
293
+ if len(buf) >= sort_size:
294
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
295
+ for x in buf:
296
+ yield x
297
+ buf = []
298
+ # The sample left over
299
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
300
+ for x in buf:
301
+ yield x
302
+
303
+
304
+ def static_batch(data, batch_size=16):
305
+ """ Static batch the data by `batch_size`
306
+
307
+ Args:
308
+ data: Iterable[{key, feat, label}]
309
+ batch_size: batch size
310
+
311
+ Returns:
312
+ Iterable[List[{key, feat, label}]]
313
+ """
314
+ buf = []
315
+ for sample in data:
316
+ buf.append(sample)
317
+ if len(buf) >= batch_size:
318
+ yield buf
319
+ buf = []
320
+ if len(buf) > 0:
321
+ yield buf
322
+
323
+
324
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
325
+ """ Dynamic batch the data until the total frames in batch
326
+ reach `max_frames_in_batch`
327
+
328
+ Args:
329
+ data: Iterable[{key, feat, label}]
330
+ max_frames_in_batch: max_frames in one batch
331
+
332
+ Returns:
333
+ Iterable[List[{key, feat, label}]]
334
+ """
335
+ buf = []
336
+ longest_frames = 0
337
+ for sample in data:
338
+ assert 'speech_feat' in sample
339
+ assert isinstance(sample['speech_feat'], torch.Tensor)
340
+ new_sample_frames = sample['speech_feat'].size(0)
341
+ longest_frames = max(longest_frames, new_sample_frames)
342
+ frames_after_padding = longest_frames * (len(buf) + 1)
343
+ if frames_after_padding > max_frames_in_batch:
344
+ yield buf
345
+ buf = [sample]
346
+ longest_frames = new_sample_frames
347
+ else:
348
+ buf.append(sample)
349
+ if len(buf) > 0:
350
+ yield buf
351
+
352
+
353
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
354
+ """ Wrapper for static/dynamic batch
355
+ """
356
+ if batch_type == 'static':
357
+ return static_batch(data, batch_size)
358
+ elif batch_type == 'dynamic':
359
+ return dynamic_batch(data, max_frames_in_batch)
360
+ else:
361
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
362
+
363
+
364
+ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
365
+ """ Padding the data into training data
366
+
367
+ Args:
368
+ data: Iterable[List[{key, feat, label}]]
369
+
370
+ Returns:
371
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
372
+ """
373
+ for sample in data:
374
+ assert isinstance(sample, list)
375
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
376
+ dtype=torch.int32)
377
+ order = torch.argsort(speech_feat_len, descending=True)
378
+
379
+ utts = [sample[i]['utt'] for i in order]
380
+ speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
381
+ speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
382
+ speech = pad_sequence(speech, batch_first=True, padding_value=0)
383
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
384
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
385
+ speech_token = pad_sequence(speech_token,
386
+ batch_first=True,
387
+ padding_value=0)
388
+ speech_feat = [sample[i]['speech_feat'] for i in order]
389
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
390
+ speech_feat = pad_sequence(speech_feat,
391
+ batch_first=True,
392
+ padding_value=0)
393
+ text = [sample[i]['text'] for i in order]
394
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
395
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
396
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
397
+ instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
398
+ instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
399
+ instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
400
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
401
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
402
+ batch = {
403
+ "utts": utts,
404
+ "speech": speech,
405
+ "speech_len": speech_len,
406
+ "speech_token": speech_token,
407
+ "speech_token_len": speech_token_len,
408
+ "speech_feat": speech_feat,
409
+ "speech_feat_len": speech_feat_len,
410
+ "text": text,
411
+ "text_token": text_token,
412
+ "text_token_len": text_token_len,
413
+ "instruct_token": instruct_token,
414
+ "instruct_token_len": instruct_token_len,
415
+ "utt_embedding": utt_embedding,
416
+ "spk_embedding": spk_embedding,
417
+ }
418
+ if gan is True:
419
+ # in gan train, we need pitch_feat
420
+ pitch_feat = [sample[i]['pitch_feat'] for i in order]
421
+ pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
422
+ pitch_feat = pad_sequence(pitch_feat,
423
+ batch_first=True,
424
+ padding_value=0)
425
+ batch["pitch_feat"] = pitch_feat
426
+ batch["pitch_feat_len"] = pitch_feat_len
427
+ else:
428
+ # only gan train needs speech, delete it to save memory
429
+ del batch["speech"]
430
+ del batch["speech_len"]
431
+ if dpo is True:
432
+ reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
433
+ reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
434
+ reject_speech_token = pad_sequence(reject_speech_token,
435
+ batch_first=True,
436
+ padding_value=0)
437
+ batch['reject_speech_token'] = reject_speech_token
438
+ batch['reject_speech_token_len'] = reject_speech_token_len
439
+ if use_spk_embedding is True:
440
+ batch["embedding"] = batch["spk_embedding"]
441
+ else:
442
+ batch["embedding"] = batch["utt_embedding"]
443
+ yield batch
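A quick, self-contained check of the dynamic batching rule above (not part of the upload): the budget is applied to the padded size, i.e. longest_frames * (len(buf) + 1), so with an 800-frame budget the dummy lengths below come out as [120, 100], then [400, 380], then [50].

import torch
from cosyvoice.dataset.processor import dynamic_batch

# five dummy samples whose only relevant field is the mel length
fake_samples = [{'speech_feat': torch.zeros(n, 80)} for n in (120, 100, 400, 380, 50)]
for i, batch in enumerate(dynamic_batch(iter(fake_samples), max_frames_in_batch=800)):
    lengths = [s['speech_feat'].size(0) for s in batch]
    print(i, lengths, 'padded frames =', max(lengths) * len(lengths))
# 0 [120, 100] padded frames = 240
# 1 [400, 380] padded frames = 800
# 2 [50] padded frames = 50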
cosyvoice/flow/DiT/dit.py ADDED
@@ -0,0 +1,176 @@
1
+
2
+ """
3
+ ein notation:
4
+ b - batch
5
+ n - sequence
6
+ nt - text sequence
7
+ nw - raw wave length
8
+ d - dimension
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+ from einops import repeat
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+ from cosyvoice.utils.mask import add_optional_chunk_mask
19
+ from cosyvoice.flow.DiT.modules import (
20
+ TimestepEmbedding,
21
+ ConvNeXtV2Block,
22
+ CausalConvPositionEmbedding,
23
+ DiTBlock,
24
+ AdaLayerNormZero_Final,
25
+ precompute_freqs_cis,
26
+ get_pos_embed_indices,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+
33
+ class TextEmbedding(nn.Module):
34
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
35
+ super().__init__()
36
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
37
+
38
+ if conv_layers > 0:
39
+ self.extra_modeling = True
40
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
41
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
42
+ self.text_blocks = nn.Sequential(
43
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
44
+ )
45
+ else:
46
+ self.extra_modeling = False
47
+
48
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
49
+ batch, text_len = text.shape[0], text.shape[1]
50
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value=0)
53
+
54
+ if drop_text: # cfg for text
55
+ text = torch.zeros_like(text)
56
+
57
+ text = self.text_embed(text) # b n -> b n d
58
+
59
+ # possible extra modeling
60
+ if self.extra_modeling:
61
+ # sinus pos emb
62
+ batch_start = torch.zeros((batch,), dtype=torch.long)
63
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
64
+ text_pos_embed = self.freqs_cis[pos_idx]
65
+ text = text + text_pos_embed
66
+
67
+ # convnextv2 blocks
68
+ text = self.text_blocks(text)
69
+
70
+ return text
71
+
72
+
73
+ # noised input audio and context mixing embedding
74
+
75
+
76
+ class InputEmbedding(nn.Module):
77
+ def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
78
+ super().__init__()
79
+ spk_dim = 0 if spk_dim is None else spk_dim
80
+ self.spk_dim = spk_dim
81
+ self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
82
+ self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
83
+
84
+ def forward(
85
+ self,
86
+ x: float["b n d"],
87
+ cond: float["b n d"],
88
+ text_embed: float["b n d"],
89
+ spks: float["b d"],
90
+ ):
91
+ to_cat = [x, cond, text_embed]
92
+ if self.spk_dim > 0:
93
+ spks = repeat(spks, "b c -> b t c", t=x.shape[1])
94
+ to_cat.append(spks)
95
+
96
+ x = self.proj(torch.cat(to_cat, dim=-1))
97
+ x = self.conv_pos_embed(x) + x
98
+ return x
99
+
100
+
101
+ # Transformer backbone using DiT blocks
102
+
103
+
104
+ class DiT(nn.Module):
105
+ def __init__(
106
+ self,
107
+ *,
108
+ dim,
109
+ depth=8,
110
+ heads=8,
111
+ dim_head=64,
112
+ dropout=0.1,
113
+ ff_mult=4,
114
+ mel_dim=80,
115
+ mu_dim=None,
116
+ long_skip_connection=False,
117
+ spk_dim=None,
118
+ out_channels=None,
119
+ static_chunk_size=50,
120
+ num_decoding_left_chunks=2
121
+ ):
122
+ super().__init__()
123
+
124
+ self.time_embed = TimestepEmbedding(dim)
125
+ if mu_dim is None:
126
+ mu_dim = mel_dim
127
+ self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
128
+
129
+ self.rotary_embed = RotaryEmbedding(dim_head)
130
+
131
+ self.dim = dim
132
+ self.depth = depth
133
+
134
+ self.transformer_blocks = nn.ModuleList(
135
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
136
+ )
137
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
138
+
139
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
140
+ self.proj_out = nn.Linear(dim, mel_dim)
141
+ self.out_channels = out_channels
142
+ self.static_chunk_size = static_chunk_size
143
+ self.num_decoding_left_chunks = num_decoding_left_chunks
144
+
145
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
146
+ x = x.transpose(1, 2)
147
+ mu = mu.transpose(1, 2)
148
+ cond = cond.transpose(1, 2)
149
+ spks = spks.unsqueeze(dim=1)
150
+ batch, seq_len = x.shape[0], x.shape[1]
151
+ if t.ndim == 0:
152
+ t = t.repeat(batch)
153
+
154
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
155
+ t = self.time_embed(t)
156
+ x = self.input_embed(x, cond, mu, spks.squeeze(1))
157
+
158
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
159
+
160
+ if self.long_skip_connection is not None:
161
+ residual = x
162
+
163
+ if streaming is True:
164
+ attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
165
+ else:
166
+ attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
167
+
168
+ for block in self.transformer_blocks:
169
+ x = block(x, t, mask=attn_mask.bool(), rope=rope)
170
+
171
+ if self.long_skip_connection is not None:
172
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
173
+
174
+ x = self.norm_out(x, t)
175
+ output = self.proj_out(x).transpose(1, 2)
176
+ return output
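For readers who want to sanity-check the tensor layout, here is a hedged smoke test; the hyper-parameters are illustrative, not the released config. It assumes x, mu and cond arrive as (batch, channels, time) and are transposed inside forward(), spks is a per-utterance embedding, and the output keeps the (batch, mel_dim, time) shape of x.

import torch
from cosyvoice.flow.DiT.dit import DiT

# small, illustrative sizes; the shipped model uses its own yaml hyper-parameters
model = DiT(dim=256, depth=2, heads=4, dim_head=64, mel_dim=80, mu_dim=80, spk_dim=80)
b, t = 2, 50
x = torch.randn(b, 80, t)       # noised mel, (batch, mel_dim, time)
mu = torch.randn(b, 80, t)      # upsampled token condition, same length as x (assumption)
cond = torch.zeros(b, 80, t)    # prompt-mel conditioning (zeros when no prompt)
spks = torch.randn(b, 80)       # speaker embedding
mask = torch.ones(b, 1, t)      # frame mask, (batch, 1, time)
timestep = torch.rand(b)        # flow-matching time in [0, 1]
with torch.no_grad():
    out = model(x, mask, mu, timestep, spks=spks, cond=cond, streaming=False)
print(out.shape)                # expected: torch.Size([2, 80, 50])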
cosyvoice/flow/DiT/modules.py ADDED
@@ -0,0 +1,616 @@
1
+
2
+ """
3
+ ein notation:
4
+ b - batch
5
+ n - sequence
6
+ nt - text sequence
7
+ nw - raw wave length
8
+ d - dimension
9
+ """
10
+
11
+ from __future__ import annotations
12
+ from typing import Optional
13
+ import math
14
+
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ import torchaudio
19
+
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+ class MelSpec(nn.Module):
25
+ def __init__(
26
+ self,
27
+ filter_length=1024,
28
+ hop_length=256,
29
+ win_length=1024,
30
+ n_mel_channels=100,
31
+ target_sample_rate=24_000,
32
+ normalize=False,
33
+ power=1,
34
+ norm=None,
35
+ center=True,
36
+ ):
37
+ super().__init__()
38
+ self.n_mel_channels = n_mel_channels
39
+
40
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
41
+ sample_rate=target_sample_rate,
42
+ n_fft=filter_length,
43
+ win_length=win_length,
44
+ hop_length=hop_length,
45
+ n_mels=n_mel_channels,
46
+ power=power,
47
+ center=center,
48
+ normalized=normalize,
49
+ norm=norm,
50
+ )
51
+
52
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
53
+
54
+ def forward(self, inp):
55
+ if len(inp.shape) == 3:
56
+ inp = inp.squeeze(1) # 'b 1 nw -> b nw'
57
+
58
+ assert len(inp.shape) == 2
59
+
60
+ if self.dummy.device != inp.device:
61
+ self.to(inp.device)
62
+
63
+ mel = self.mel_stft(inp)
64
+ mel = mel.clamp(min=1e-5).log()
65
+ return mel
66
+
67
+
68
+ # sinusoidal position embedding
69
+
70
+
71
+ class SinusPositionEmbedding(nn.Module):
72
+ def __init__(self, dim):
73
+ super().__init__()
74
+ self.dim = dim
75
+
76
+ def forward(self, x, scale=1000):
77
+ device = x.device
78
+ half_dim = self.dim // 2
79
+ emb = math.log(10000) / (half_dim - 1)
80
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
+ return emb
84
+
85
+
86
+ # convolutional position embedding
87
+
88
+
89
+ class ConvPositionEmbedding(nn.Module):
90
+ def __init__(self, dim, kernel_size=31, groups=16):
91
+ super().__init__()
92
+ assert kernel_size % 2 != 0
93
+ self.conv1d = nn.Sequential(
94
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
95
+ nn.Mish(),
96
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
97
+ nn.Mish(),
98
+ )
99
+
100
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
101
+ if mask is not None:
102
+ mask = mask[..., None]
103
+ x = x.masked_fill(~mask, 0.0)
104
+
105
+ x = x.permute(0, 2, 1)
106
+ x = self.conv1d(x)
107
+ out = x.permute(0, 2, 1)
108
+
109
+ if mask is not None:
110
+ out = out.masked_fill(~mask, 0.0)
111
+
112
+ return out
113
+
114
+
115
+ class CausalConvPositionEmbedding(nn.Module):
116
+ def __init__(self, dim, kernel_size=31, groups=16):
117
+ super().__init__()
118
+ assert kernel_size % 2 != 0
119
+ self.kernel_size = kernel_size
120
+ self.conv1 = nn.Sequential(
121
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
122
+ nn.Mish(),
123
+ )
124
+ self.conv2 = nn.Sequential(
125
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
126
+ nn.Mish(),
127
+ )
128
+
129
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
130
+ if mask is not None:
131
+ mask = mask[..., None]
132
+ x = x.masked_fill(~mask, 0.0)
133
+
134
+ x = x.permute(0, 2, 1)
135
+ x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
136
+ x = self.conv1(x)
137
+ x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
138
+ x = self.conv2(x)
139
+ out = x.permute(0, 2, 1)
140
+
141
+ if mask is not None:
142
+ out = out.masked_fill(~mask, 0.0)
143
+
144
+ return out
145
+
146
+
147
+ # rotary positional embedding related
148
+
149
+
150
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
151
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
152
+ # has some connection to NTK literature
153
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
154
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
155
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
156
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
157
+ t = torch.arange(end, device=freqs.device) # type: ignore
158
+ freqs = torch.outer(t, freqs).float() # type: ignore
159
+ freqs_cos = torch.cos(freqs) # real part
160
+ freqs_sin = torch.sin(freqs) # imaginary part
161
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
162
+
163
+
164
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
165
+ # length = length if isinstance(length, int) else length.max()
166
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
167
+ pos = (
168
+ start.unsqueeze(1)
169
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
170
+ )
171
+ # avoid extra long error.
172
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
173
+ return pos
174
+
175
+
176
+ # Global Response Normalization layer (Instance Normalization ?)
177
+
178
+
179
+ class GRN(nn.Module):
180
+ def __init__(self, dim):
181
+ super().__init__()
182
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
183
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
184
+
185
+ def forward(self, x):
186
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
187
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
188
+ return self.gamma * (x * Nx) + self.beta + x
189
+
190
+
191
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
192
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
193
+
194
+
195
+ class ConvNeXtV2Block(nn.Module):
196
+ def __init__(
197
+ self,
198
+ dim: int,
199
+ intermediate_dim: int,
200
+ dilation: int = 1,
201
+ ):
202
+ super().__init__()
203
+ padding = (dilation * (7 - 1)) // 2
204
+ self.dwconv = nn.Conv1d(
205
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
206
+ ) # depthwise conv
207
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
208
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
209
+ self.act = nn.GELU()
210
+ self.grn = GRN(intermediate_dim)
211
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
212
+
213
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
214
+ residual = x
215
+ x = x.transpose(1, 2) # b n d -> b d n
216
+ x = self.dwconv(x)
217
+ x = x.transpose(1, 2) # b d n -> b n d
218
+ x = self.norm(x)
219
+ x = self.pwconv1(x)
220
+ x = self.act(x)
221
+ x = self.grn(x)
222
+ x = self.pwconv2(x)
223
+ return residual + x
224
+
225
+
226
+ # AdaLayerNormZero
227
+ # return with modulated x for attn input, and params for later mlp modulation
228
+
229
+
230
+ class AdaLayerNormZero(nn.Module):
231
+ def __init__(self, dim):
232
+ super().__init__()
233
+
234
+ self.silu = nn.SiLU()
235
+ self.linear = nn.Linear(dim, dim * 6)
236
+
237
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
238
+
239
+ def forward(self, x, emb=None):
240
+ emb = self.linear(self.silu(emb))
241
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
242
+
243
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
244
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
245
+
246
+
247
+ # AdaLayerNormZero for final layer
248
+ # return only with modulated x for attn input, cuz no more mlp modulation
249
+
250
+
251
+ class AdaLayerNormZero_Final(nn.Module):
252
+ def __init__(self, dim):
253
+ super().__init__()
254
+
255
+ self.silu = nn.SiLU()
256
+ self.linear = nn.Linear(dim, dim * 2)
257
+
258
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
259
+
260
+ def forward(self, x, emb):
261
+ emb = self.linear(self.silu(emb))
262
+ scale, shift = torch.chunk(emb, 2, dim=1)
263
+
264
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
265
+ return x
266
+
267
+
268
+ # FeedForward
269
+
270
+
271
+ class FeedForward(nn.Module):
272
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
273
+ super().__init__()
274
+ inner_dim = int(dim * mult)
275
+ dim_out = dim_out if dim_out is not None else dim
276
+
277
+ activation = nn.GELU(approximate=approximate)
278
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
279
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
280
+
281
+ def forward(self, x):
282
+ return self.ff(x)
283
+
284
+
285
+ # Attention with possible joint part
286
+ # modified from diffusers/src/diffusers/models/attention_processor.py
287
+
288
+
289
+ class Attention(nn.Module):
290
+ def __init__(
291
+ self,
292
+ processor: JointAttnProcessor | AttnProcessor,
293
+ dim: int,
294
+ heads: int = 8,
295
+ dim_head: int = 64,
296
+ dropout: float = 0.0,
297
+ context_dim: Optional[int] = None, # if not None -> joint attention
298
+ context_pre_only=None,
299
+ ):
300
+ super().__init__()
301
+
302
+ if not hasattr(F, "scaled_dot_product_attention"):
303
+ raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
304
+
305
+ self.processor = processor
306
+
307
+ self.dim = dim
308
+ self.heads = heads
309
+ self.inner_dim = dim_head * heads
310
+ self.dropout = dropout
311
+
312
+ self.context_dim = context_dim
313
+ self.context_pre_only = context_pre_only
314
+
315
+ self.to_q = nn.Linear(dim, self.inner_dim)
316
+ self.to_k = nn.Linear(dim, self.inner_dim)
317
+ self.to_v = nn.Linear(dim, self.inner_dim)
318
+
319
+ if self.context_dim is not None:
320
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
321
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
322
+ if self.context_pre_only is not None:
323
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
324
+
325
+ self.to_out = nn.ModuleList([])
326
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
327
+ self.to_out.append(nn.Dropout(dropout))
328
+
329
+ if self.context_pre_only is not None and not self.context_pre_only:
330
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
331
+
332
+ def forward(
333
+ self,
334
+ x: float["b n d"], # noised input x # noqa: F722
335
+ c: float["b n d"] = None, # context c # noqa: F722
336
+ mask: bool["b n"] | None = None, # noqa: F722
337
+ rope=None, # rotary position embedding for x
338
+ c_rope=None, # rotary position embedding for c
339
+ ) -> torch.Tensor:
340
+ if c is not None:
341
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
342
+ else:
343
+ return self.processor(self, x, mask=mask, rope=rope)
344
+
345
+
346
+ # Attention processor
347
+
348
+
349
+ class AttnProcessor:
350
+ def __init__(self):
351
+ pass
352
+
353
+ def __call__(
354
+ self,
355
+ attn: Attention,
356
+ x: float["b n d"], # noised input x # noqa: F722
357
+ mask: bool["b n"] | None = None, # noqa: F722
358
+ rope=None, # rotary position embedding
359
+ ) -> torch.FloatTensor:
360
+ batch_size = x.shape[0]
361
+
362
+ # `sample` projections.
363
+ query = attn.to_q(x)
364
+ key = attn.to_k(x)
365
+ value = attn.to_v(x)
366
+
367
+ # apply rotary position embedding
368
+ if rope is not None:
369
+ freqs, xpos_scale = rope
370
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
371
+
372
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
373
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
374
+
375
+ # attention
376
+ inner_dim = key.shape[-1]
377
+ head_dim = inner_dim // attn.heads
378
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
381
+
382
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
383
+ if mask is not None:
384
+ attn_mask = mask
385
+ if attn_mask.dim() == 2:
386
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
387
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
388
+ else:
389
+ attn_mask = None
390
+
391
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
392
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
393
+ x = x.to(query.dtype)
394
+
395
+ # linear proj
396
+ x = attn.to_out[0](x)
397
+ # dropout
398
+ x = attn.to_out[1](x)
399
+
400
+ if mask is not None:
401
+ if mask.dim() == 2:
402
+ mask = mask.unsqueeze(-1)
403
+ else:
404
+ mask = mask[:, 0, -1].unsqueeze(-1)
405
+ x = x.masked_fill(~mask, 0.0)
406
+
407
+ return x
408
+
409
+
410
+ # Joint Attention processor for MM-DiT
411
+ # modified from diffusers/src/diffusers/models/attention_processor.py
412
+
413
+
414
+ class JointAttnProcessor:
415
+ def __init__(self):
416
+ pass
417
+
418
+ def __call__(
419
+ self,
420
+ attn: Attention,
421
+ x: float["b n d"], # noised input x # noqa: F722
422
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
423
+ mask: bool["b n"] | None = None, # noqa: F722
424
+ rope=None, # rotary position embedding for x
425
+ c_rope=None, # rotary position embedding for c
426
+ ) -> torch.FloatTensor:
427
+ residual = x
428
+
429
+ batch_size = c.shape[0]
430
+
431
+ # `sample` projections.
432
+ query = attn.to_q(x)
433
+ key = attn.to_k(x)
434
+ value = attn.to_v(x)
435
+
436
+ # `context` projections.
437
+ c_query = attn.to_q_c(c)
438
+ c_key = attn.to_k_c(c)
439
+ c_value = attn.to_v_c(c)
440
+
441
+ # apply rope for context and noised input independently
442
+ if rope is not None:
443
+ freqs, xpos_scale = rope
444
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
445
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
446
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
447
+ if c_rope is not None:
448
+ freqs, xpos_scale = c_rope
449
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
450
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
451
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
452
+
453
+ # attention
454
+ query = torch.cat([query, c_query], dim=1)
455
+ key = torch.cat([key, c_key], dim=1)
456
+ value = torch.cat([value, c_value], dim=1)
457
+
458
+ inner_dim = key.shape[-1]
459
+ head_dim = inner_dim // attn.heads
460
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
461
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
462
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
463
+
464
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
465
+ if mask is not None:
466
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
467
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
468
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
469
+ else:
470
+ attn_mask = None
471
+
472
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
473
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
474
+ x = x.to(query.dtype)
475
+
476
+ # Split the attention outputs.
477
+ x, c = (
478
+ x[:, : residual.shape[1]],
479
+ x[:, residual.shape[1]:],
480
+ )
481
+
482
+ # linear proj
483
+ x = attn.to_out[0](x)
484
+ # dropout
485
+ x = attn.to_out[1](x)
486
+ if not attn.context_pre_only:
487
+ c = attn.to_out_c(c)
488
+
489
+ if mask is not None:
490
+ mask = mask.unsqueeze(-1)
491
+ x = x.masked_fill(~mask, 0.0)
492
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
493
+
494
+ return x, c
495
+
496
+
497
+ # DiT Block
498
+
499
+
500
+ class DiTBlock(nn.Module):
501
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
502
+ super().__init__()
503
+
504
+ self.attn_norm = AdaLayerNormZero(dim)
505
+ self.attn = Attention(
506
+ processor=AttnProcessor(),
507
+ dim=dim,
508
+ heads=heads,
509
+ dim_head=dim_head,
510
+ dropout=dropout,
511
+ )
512
+
513
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
514
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
515
+
516
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
517
+ # pre-norm & modulation for attention input
518
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
519
+
520
+ # attention
521
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
522
+
523
+ # process attention output for input x
524
+ x = x + gate_msa.unsqueeze(1) * attn_output
525
+
526
+ ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
527
+ ff_output = self.ff(ff_norm)
528
+ x = x + gate_mlp.unsqueeze(1) * ff_output
529
+
530
+ return x
531
+
532
+
533
+ # MMDiT Block https://arxiv.org/abs/2403.03206
534
+
535
+
536
+ class MMDiTBlock(nn.Module):
537
+ r"""
538
+ modified from diffusers/src/diffusers/models/attention.py
539
+
540
+ notes.
541
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
542
+ _x: noised input related. (right part)
543
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
544
+ """
545
+
546
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
547
+ super().__init__()
548
+
549
+ self.context_pre_only = context_pre_only
550
+
551
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
552
+ self.attn_norm_x = AdaLayerNormZero(dim)
553
+ self.attn = Attention(
554
+ processor=JointAttnProcessor(),
555
+ dim=dim,
556
+ heads=heads,
557
+ dim_head=dim_head,
558
+ dropout=dropout,
559
+ context_dim=dim,
560
+ context_pre_only=context_pre_only,
561
+ )
562
+
563
+ if not context_pre_only:
564
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
565
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
566
+ else:
567
+ self.ff_norm_c = None
568
+ self.ff_c = None
569
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
570
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
571
+
572
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
573
+ # pre-norm & modulation for attention input
574
+ if self.context_pre_only:
575
+ norm_c = self.attn_norm_c(c, t)
576
+ else:
577
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
578
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
579
+
580
+ # attention
581
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
582
+
583
+ # process attention output for context c
584
+ if self.context_pre_only:
585
+ c = None
586
+ else: # if not last layer
587
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
588
+
589
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
590
+ c_ff_output = self.ff_c(norm_c)
591
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
592
+
593
+ # process attention output for input x
594
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
595
+
596
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
597
+ x_ff_output = self.ff_x(norm_x)
598
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
599
+
600
+ return c, x
601
+
602
+
603
+ # time step conditioning embedding
604
+
605
+
606
+ class TimestepEmbedding(nn.Module):
607
+ def __init__(self, dim, freq_embed_dim=256):
608
+ super().__init__()
609
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
610
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
611
+
612
+ def forward(self, timestep: float["b"]): # noqa: F821
613
+ time_hidden = self.time_embed(timestep)
614
+ time_hidden = time_hidden.to(timestep.dtype)
615
+ time = self.time_mlp(time_hidden) # b d
616
+ return time
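A few hedged shape checks for the helper modules above (illustrative values only), showing how the timestep embedding, the ConvNeXt-V2 block and the precomputed sinusoidal table are meant to be combined by dit.py.

import torch
from cosyvoice.flow.DiT.modules import (
    TimestepEmbedding, ConvNeXtV2Block, precompute_freqs_cis, get_pos_embed_indices)

t_embed = TimestepEmbedding(dim=256)
print(t_embed(torch.rand(4)).shape)          # torch.Size([4, 256]): one conditioning vector per timestep

block = ConvNeXtV2Block(dim=80, intermediate_dim=160)
print(block(torch.randn(2, 100, 80)).shape)  # torch.Size([2, 100, 80]): residual block keeps (b, n, d)

freqs = precompute_freqs_cis(dim=80, end=4096)                 # (4096, 80), [cos | sin] halves
start = torch.zeros(2, dtype=torch.long)                       # per-sample start positions
idx = get_pos_embed_indices(start, length=100, max_pos=4096)   # (2, 100) clamped position indices
print(freqs[idx].shape)                      # torch.Size([2, 100, 80]): added to text embeddings in dit.py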
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,494 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import pack, rearrange, repeat
19
+ from cosyvoice.utils.common import mask_to_bias
20
+ from cosyvoice.utils.mask import add_optional_chunk_mask
21
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
22
+ from matcha.models.components.transformer import BasicTransformerBlock
23
+
24
+
25
+ class Transpose(torch.nn.Module):
26
+ def __init__(self, dim0: int, dim1: int):
27
+ super().__init__()
28
+ self.dim0 = dim0
29
+ self.dim1 = dim1
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ x = torch.transpose(x, self.dim0, self.dim1)
33
+ return x
34
+
35
+
36
+ class CausalConv1d(torch.nn.Conv1d):
37
+ def __init__(
38
+ self,
39
+ in_channels: int,
40
+ out_channels: int,
41
+ kernel_size: int,
42
+ stride: int = 1,
43
+ dilation: int = 1,
44
+ groups: int = 1,
45
+ bias: bool = True,
46
+ padding_mode: str = 'zeros',
47
+ device=None,
48
+ dtype=None
49
+ ) -> None:
50
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
51
+ kernel_size, stride,
52
+ padding=0, dilation=dilation,
53
+ groups=groups, bias=bias,
54
+ padding_mode=padding_mode,
55
+ device=device, dtype=dtype)
56
+ assert stride == 1
57
+ self.causal_padding = kernel_size - 1
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
61
+ x = super(CausalConv1d, self).forward(x)
62
+ return x
63
+
64
+
65
+ class CausalBlock1D(Block1D):
66
+ def __init__(self, dim: int, dim_out: int):
67
+ super(CausalBlock1D, self).__init__(dim, dim_out)
68
+ self.block = torch.nn.Sequential(
69
+ CausalConv1d(dim, dim_out, 3),
70
+ Transpose(1, 2),
71
+ nn.LayerNorm(dim_out),
72
+ Transpose(1, 2),
73
+ nn.Mish(),
74
+ )
75
+
76
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
77
+ output = self.block(x * mask)
78
+ return output * mask
79
+
80
+
81
+ class CausalResnetBlock1D(ResnetBlock1D):
82
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
83
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
84
+ self.block1 = CausalBlock1D(dim, dim_out)
85
+ self.block2 = CausalBlock1D(dim_out, dim_out)
86
+
87
+
88
+ class ConditionalDecoder(nn.Module):
89
+ def __init__(
90
+ self,
91
+ in_channels,
92
+ out_channels,
93
+ channels=(256, 256),
94
+ dropout=0.05,
95
+ attention_head_dim=64,
96
+ n_blocks=1,
97
+ num_mid_blocks=2,
98
+ num_heads=4,
99
+ act_fn="snake",
100
+ ):
101
+ """
102
+ This decoder requires an input with the same shape as the target, so if your text content
103
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
104
+ """
105
+ super().__init__()
106
+ channels = tuple(channels)
107
+ self.in_channels = in_channels
108
+ self.out_channels = out_channels
109
+
110
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
111
+ time_embed_dim = channels[0] * 4
112
+ self.time_mlp = TimestepEmbedding(
113
+ in_channels=in_channels,
114
+ time_embed_dim=time_embed_dim,
115
+ act_fn="silu",
116
+ )
117
+ self.down_blocks = nn.ModuleList([])
118
+ self.mid_blocks = nn.ModuleList([])
119
+ self.up_blocks = nn.ModuleList([])
120
+
121
+ output_channel = in_channels
122
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
123
+ input_channel = output_channel
124
+ output_channel = channels[i]
125
+ is_last = i == len(channels) - 1
126
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
127
+ transformer_blocks = nn.ModuleList(
128
+ [
129
+ BasicTransformerBlock(
130
+ dim=output_channel,
131
+ num_attention_heads=num_heads,
132
+ attention_head_dim=attention_head_dim,
133
+ dropout=dropout,
134
+ activation_fn=act_fn,
135
+ )
136
+ for _ in range(n_blocks)
137
+ ]
138
+ )
139
+ downsample = (
140
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
141
+ )
142
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
143
+
144
+ for _ in range(num_mid_blocks):
145
+ input_channel = channels[-1]
146
+ out_channels = channels[-1]
147
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
148
+
149
+ transformer_blocks = nn.ModuleList(
150
+ [
151
+ BasicTransformerBlock(
152
+ dim=output_channel,
153
+ num_attention_heads=num_heads,
154
+ attention_head_dim=attention_head_dim,
155
+ dropout=dropout,
156
+ activation_fn=act_fn,
157
+ )
158
+ for _ in range(n_blocks)
159
+ ]
160
+ )
161
+
162
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
163
+
164
+ channels = channels[::-1] + (channels[0],)
165
+ for i in range(len(channels) - 1):
166
+ input_channel = channels[i] * 2
167
+ output_channel = channels[i + 1]
168
+ is_last = i == len(channels) - 2
169
+ resnet = ResnetBlock1D(
170
+ dim=input_channel,
171
+ dim_out=output_channel,
172
+ time_emb_dim=time_embed_dim,
173
+ )
174
+ transformer_blocks = nn.ModuleList(
175
+ [
176
+ BasicTransformerBlock(
177
+ dim=output_channel,
178
+ num_attention_heads=num_heads,
179
+ attention_head_dim=attention_head_dim,
180
+ dropout=dropout,
181
+ activation_fn=act_fn,
182
+ )
183
+ for _ in range(n_blocks)
184
+ ]
185
+ )
186
+ upsample = (
187
+ Upsample1D(output_channel, use_conv_transpose=True)
188
+ if not is_last
189
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
190
+ )
191
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
192
+ self.final_block = Block1D(channels[-1], channels[-1])
193
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
194
+ self.initialize_weights()
195
+
196
+ def initialize_weights(self):
197
+ for m in self.modules():
198
+ if isinstance(m, nn.Conv1d):
199
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
200
+ if m.bias is not None:
201
+ nn.init.constant_(m.bias, 0)
202
+ elif isinstance(m, nn.GroupNorm):
203
+ nn.init.constant_(m.weight, 1)
204
+ nn.init.constant_(m.bias, 0)
205
+ elif isinstance(m, nn.Linear):
206
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
207
+ if m.bias is not None:
208
+ nn.init.constant_(m.bias, 0)
209
+
210
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
211
+ """Forward pass of the UNet1DConditional model.
212
+
213
+ Args:
214
+ x (torch.Tensor): shape (batch_size, in_channels, time)
215
+ mask (_type_): shape (batch_size, 1, time)
216
+ t (_type_): shape (batch_size)
217
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
218
+ cond (_type_, optional): placeholder for future use. Defaults to None.
219
+
220
+ Raises:
221
+ ValueError: _description_
222
+ ValueError: _description_
223
+
224
+ Returns:
225
+ _type_: _description_
226
+ """
227
+
228
+ t = self.time_embeddings(t).to(t.dtype)
229
+ t = self.time_mlp(t)
230
+
231
+ x = pack([x, mu], "b * t")[0]
232
+
233
+ if spks is not None:
234
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
235
+ x = pack([x, spks], "b * t")[0]
236
+ if cond is not None:
237
+ x = pack([x, cond], "b * t")[0]
238
+
239
+ hiddens = []
240
+ masks = [mask]
241
+ for resnet, transformer_blocks, downsample in self.down_blocks:
242
+ mask_down = masks[-1]
243
+ x = resnet(x, mask_down, t)
244
+ x = rearrange(x, "b c t -> b t c").contiguous()
245
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
246
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
247
+ for transformer_block in transformer_blocks:
248
+ x = transformer_block(
249
+ hidden_states=x,
250
+ attention_mask=attn_mask,
251
+ timestep=t,
252
+ )
253
+ x = rearrange(x, "b t c -> b c t").contiguous()
254
+ hiddens.append(x) # Save hidden states for skip connections
255
+ x = downsample(x * mask_down)
256
+ masks.append(mask_down[:, :, ::2])
257
+ masks = masks[:-1]
258
+ mask_mid = masks[-1]
259
+
260
+ for resnet, transformer_blocks in self.mid_blocks:
261
+ x = resnet(x, mask_mid, t)
262
+ x = rearrange(x, "b c t -> b t c").contiguous()
263
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
264
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
265
+ for transformer_block in transformer_blocks:
266
+ x = transformer_block(
267
+ hidden_states=x,
268
+ attention_mask=attn_mask,
269
+ timestep=t,
270
+ )
271
+ x = rearrange(x, "b t c -> b c t").contiguous()
272
+
273
+ for resnet, transformer_blocks, upsample in self.up_blocks:
274
+ mask_up = masks.pop()
275
+ skip = hiddens.pop()
276
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
277
+ x = resnet(x, mask_up, t)
278
+ x = rearrange(x, "b c t -> b t c").contiguous()
279
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
280
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
281
+ for transformer_block in transformer_blocks:
282
+ x = transformer_block(
283
+ hidden_states=x,
284
+ attention_mask=attn_mask,
285
+ timestep=t,
286
+ )
287
+ x = rearrange(x, "b t c -> b c t").contiguous()
288
+ x = upsample(x * mask_up)
289
+ x = self.final_block(x, mask_up)
290
+ output = self.final_proj(x * mask_up)
291
+ return output * mask
292
+
293
+
294
+ class CausalConditionalDecoder(ConditionalDecoder):
295
+ def __init__(
296
+ self,
297
+ in_channels,
298
+ out_channels,
299
+ channels=(256, 256),
300
+ dropout=0.05,
301
+ attention_head_dim=64,
302
+ n_blocks=1,
303
+ num_mid_blocks=2,
304
+ num_heads=4,
305
+ act_fn="snake",
306
+ static_chunk_size=50,
307
+ num_decoding_left_chunks=2,
308
+ ):
309
+ """
310
+ This decoder requires an input with the same shape as the target, so if your text content
311
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
312
+ """
313
+ torch.nn.Module.__init__(self)
314
+ channels = tuple(channels)
315
+ self.in_channels = in_channels
316
+ self.out_channels = out_channels
317
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
318
+ time_embed_dim = channels[0] * 4
319
+ self.time_mlp = TimestepEmbedding(
320
+ in_channels=in_channels,
321
+ time_embed_dim=time_embed_dim,
322
+ act_fn="silu",
323
+ )
324
+ self.static_chunk_size = static_chunk_size
325
+ self.num_decoding_left_chunks = num_decoding_left_chunks
326
+ self.down_blocks = nn.ModuleList([])
327
+ self.mid_blocks = nn.ModuleList([])
328
+ self.up_blocks = nn.ModuleList([])
329
+
330
+ output_channel = in_channels
331
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
332
+ input_channel = output_channel
333
+ output_channel = channels[i]
334
+ is_last = i == len(channels) - 1
335
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
336
+ transformer_blocks = nn.ModuleList(
337
+ [
338
+ BasicTransformerBlock(
339
+ dim=output_channel,
340
+ num_attention_heads=num_heads,
341
+ attention_head_dim=attention_head_dim,
342
+ dropout=dropout,
343
+ activation_fn=act_fn,
344
+ )
345
+ for _ in range(n_blocks)
346
+ ]
347
+ )
348
+ downsample = (
349
+ Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
350
+ )
351
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
352
+
353
+ for _ in range(num_mid_blocks):
354
+ input_channel = channels[-1]
355
+ out_channels = channels[-1]
356
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
357
+
358
+ transformer_blocks = nn.ModuleList(
359
+ [
360
+ BasicTransformerBlock(
361
+ dim=output_channel,
362
+ num_attention_heads=num_heads,
363
+ attention_head_dim=attention_head_dim,
364
+ dropout=dropout,
365
+ activation_fn=act_fn,
366
+ )
367
+ for _ in range(n_blocks)
368
+ ]
369
+ )
370
+
371
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
372
+
373
+ channels = channels[::-1] + (channels[0],)
374
+ for i in range(len(channels) - 1):
375
+ input_channel = channels[i] * 2
376
+ output_channel = channels[i + 1]
377
+ is_last = i == len(channels) - 2
378
+ resnet = CausalResnetBlock1D(
379
+ dim=input_channel,
380
+ dim_out=output_channel,
381
+ time_emb_dim=time_embed_dim,
382
+ )
383
+ transformer_blocks = nn.ModuleList(
384
+ [
385
+ BasicTransformerBlock(
386
+ dim=output_channel,
387
+ num_attention_heads=num_heads,
388
+ attention_head_dim=attention_head_dim,
389
+ dropout=dropout,
390
+ activation_fn=act_fn,
391
+ )
392
+ for _ in range(n_blocks)
393
+ ]
394
+ )
395
+ upsample = (
396
+ Upsample1D(output_channel, use_conv_transpose=True)
397
+ if not is_last
398
+ else CausalConv1d(output_channel, output_channel, 3)
399
+ )
400
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
401
+ self.final_block = CausalBlock1D(channels[-1], channels[-1])
402
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
403
+ self.initialize_weights()
404
+
405
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
406
+ """Forward pass of the UNet1DConditional model.
407
+
408
+ Args:
409
+ x (torch.Tensor): shape (batch_size, in_channels, time)
410
+ mask (_type_): shape (batch_size, 1, time)
411
+ t (_type_): shape (batch_size)
412
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
413
+ cond (torch.Tensor, optional): conditioning features, shape (batch_size, out_channels, time). Defaults to None.
414
+
415
+ Raises:
416
+ ValueError: _description_
417
+ ValueError: _description_
418
+
419
+ Returns:
420
+ torch.Tensor: output of shape (batch_size, out_channels, time)
421
+ """
422
+ t = self.time_embeddings(t).to(t.dtype)
423
+ t = self.time_mlp(t)
424
+
425
+ x = pack([x, mu], "b * t")[0]
426
+
427
+ if spks is not None:
428
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
429
+ x = pack([x, spks], "b * t")[0]
430
+ if cond is not None:
431
+ x = pack([x, cond], "b * t")[0]
432
+
433
+ hiddens = []
434
+ masks = [mask]
435
+ for resnet, transformer_blocks, downsample in self.down_blocks:
436
+ mask_down = masks[-1]
437
+ x = resnet(x, mask_down, t)
438
+ x = rearrange(x, "b c t -> b t c").contiguous()
439
+ if streaming is True:
440
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
441
+ else:
442
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
443
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
444
+ for transformer_block in transformer_blocks:
445
+ x = transformer_block(
446
+ hidden_states=x,
447
+ attention_mask=attn_mask,
448
+ timestep=t,
449
+ )
450
+ x = rearrange(x, "b t c -> b c t").contiguous()
451
+ hiddens.append(x) # Save hidden states for skip connections
452
+ x = downsample(x * mask_down)
453
+ masks.append(mask_down[:, :, ::2])
454
+ masks = masks[:-1]
455
+ mask_mid = masks[-1]
456
+
457
+ for resnet, transformer_blocks in self.mid_blocks:
458
+ x = resnet(x, mask_mid, t)
459
+ x = rearrange(x, "b c t -> b t c").contiguous()
460
+ if streaming is True:
461
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
462
+ else:
463
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
464
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
465
+ for transformer_block in transformer_blocks:
466
+ x = transformer_block(
467
+ hidden_states=x,
468
+ attention_mask=attn_mask,
469
+ timestep=t,
470
+ )
471
+ x = rearrange(x, "b t c -> b c t").contiguous()
472
+
473
+ for resnet, transformer_blocks, upsample in self.up_blocks:
474
+ mask_up = masks.pop()
475
+ skip = hiddens.pop()
476
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
477
+ x = resnet(x, mask_up, t)
478
+ x = rearrange(x, "b c t -> b t c").contiguous()
479
+ if streaming is True:
480
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
481
+ else:
482
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
483
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
484
+ for transformer_block in transformer_blocks:
485
+ x = transformer_block(
486
+ hidden_states=x,
487
+ attention_mask=attn_mask,
488
+ timestep=t,
489
+ )
490
+ x = rearrange(x, "b t c -> b c t").contiguous()
491
+ x = upsample(x * mask_up)
492
+ x = self.final_block(x, mask_up)
493
+ output = self.final_proj(x * mask_up)
494
+ return output * mask
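For orientation, below is a minimal smoke-test sketch (not part of the uploaded file) of the CausalConditionalDecoder defined above. It assumes the class is importable from cosyvoice.flow.decoder and that x, mu, spks and cond are all 80-dimensional as in the decoder_conf dictionaries used later in this upload, so the packed input is 240 channels when cond is omitted; the sizes below are illustrative, not the shipped configuration.

import torch
from cosyvoice.flow.decoder import CausalConditionalDecoder  # import path assumed from this upload

# 240 = 80 (x) + 80 (mu) + 80 (spks); cond is left out in this sketch
decoder = CausalConditionalDecoder(in_channels=240, out_channels=80, channels=(256, 256), act_fn="gelu")

B, T = 1, 100                      # T even, so the single Downsample1D/Upsample1D pair round-trips cleanly
x = torch.randn(B, 80, T)          # noisy mel being refined
mu = torch.randn(B, 80, T)         # encoder output already at mel resolution
spks = torch.randn(B, 80)          # projected speaker embedding
mask = torch.ones(B, 1, T)
t = torch.rand(B)                  # flow-matching timestep in [0, 1]

with torch.no_grad():
    out = decoder(x, mask, mu, t, spks=spks, cond=None, streaming=False)
print(out.shape)                   # expected torch.Size([1, 80, 100])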
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,432 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
41
+ super().__init__()
42
+ self.input_size = input_size
43
+ self.output_size = output_size
44
+ self.decoder_conf = decoder_conf
45
+ self.vocab_size = vocab_size
46
+ self.output_type = output_type
47
+ self.input_frame_rate = input_frame_rate
48
+ logging.info(f"input frame rate={self.input_frame_rate}")
49
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
50
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
51
+ self.encoder = encoder
52
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
53
+ self.decoder = decoder
54
+ self.length_regulator = length_regulator
55
+ self.only_mask_loss = only_mask_loss
56
+
57
+ def forward(
58
+ self,
59
+ batch: dict,
60
+ device: torch.device,
61
+ ) -> Dict[str, Optional[torch.Tensor]]:
62
+ token = batch['speech_token'].to(device)
63
+ token_len = batch['speech_token_len'].to(device)
64
+ feat = batch['speech_feat'].to(device)
65
+ feat_len = batch['speech_feat_len'].to(device)
66
+ embedding = batch['embedding'].to(device)
67
+
68
+ # xvec projection
69
+ embedding = F.normalize(embedding, dim=1)
70
+ embedding = self.spk_embed_affine_layer(embedding)
71
+
72
+ # concat text and prompt_text
73
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
74
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
75
+
76
+ # text encode
77
+ h, h_lengths = self.encoder(token, token_len)
78
+ h = self.encoder_proj(h)
79
+ h, h_lengths = self.length_regulator(h, feat_len)
80
+
81
+ # get conditions
82
+ conds = torch.zeros(feat.shape, device=token.device)
83
+ for i, j in enumerate(feat_len):
84
+ if random.random() < 0.5:
85
+ continue
86
+ index = random.randint(0, int(0.3 * j))
87
+ conds[i, :index] = feat[i, :index]
88
+ conds = conds.transpose(1, 2)
89
+
90
+ mask = (~make_pad_mask(feat_len)).to(h)
91
+ # NOTE this is unnecessary, feat and h already have the same shape
92
+ loss, _ = self.decoder.compute_loss(
93
+ feat.transpose(1, 2).contiguous(),
94
+ mask.unsqueeze(1),
95
+ h.transpose(1, 2).contiguous(),
96
+ embedding,
97
+ cond=conds
98
+ )
99
+ return {'loss': loss}
100
+
101
+ @torch.inference_mode()
102
+ def inference(self,
103
+ token,
104
+ token_len,
105
+ prompt_token,
106
+ prompt_token_len,
107
+ prompt_feat,
108
+ prompt_feat_len,
109
+ embedding,
110
+ flow_cache):
111
+ assert token.shape[0] == 1
112
+ # xvec projection
113
+ embedding = F.normalize(embedding, dim=1)
114
+ embedding = self.spk_embed_affine_layer(embedding)
115
+
116
+ # concat speech token and prompt speech token
117
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
118
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
119
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
120
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
121
+
122
+ # text encode
123
+ h, h_lengths = self.encoder(token, token_len)
124
+ h = self.encoder_proj(h)
125
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
126
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
127
+
128
+ # get conditions
129
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
130
+ conds[:, :mel_len1] = prompt_feat
131
+ conds = conds.transpose(1, 2)
132
+
133
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
134
+ feat, flow_cache = self.decoder(
135
+ mu=h.transpose(1, 2).contiguous(),
136
+ mask=mask.unsqueeze(1),
137
+ spks=embedding,
138
+ cond=conds,
139
+ n_timesteps=10,
140
+ prompt_len=mel_len1,
141
+ cache=flow_cache
142
+ )
143
+ feat = feat[:, :, mel_len1:]
144
+ assert feat.shape[2] == mel_len2
145
+ return feat.float(), flow_cache
146
+
147
+
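The mel length used in MaskedDiffWithXvec.inference above converts the speech-token rate (input_frame_rate) into the 22050 Hz / 256-hop mel frame rate. A quick worked example of that arithmetic (token count chosen only for illustration):

# mel_len2 = int(token_len2 / self.input_frame_rate * 22050 / 256), as in the method above
token_len2 = 100            # 100 speech tokens
input_frame_rate = 50       # tokens per second -> 2 s of audio
mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
print(mel_len2)             # int(172.27) == 172 mel frames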
148
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
149
+ def __init__(self,
150
+ input_size: int = 512,
151
+ output_size: int = 80,
152
+ spk_embed_dim: int = 192,
153
+ output_type: str = "mel",
154
+ vocab_size: int = 4096,
155
+ input_frame_rate: int = 50,
156
+ only_mask_loss: bool = True,
157
+ token_mel_ratio: int = 2,
158
+ pre_lookahead_len: int = 3,
159
+ encoder: torch.nn.Module = None,
160
+ decoder: torch.nn.Module = None,
161
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
162
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
163
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
164
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
165
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
166
+ super().__init__()
167
+ self.input_size = input_size
168
+ self.output_size = output_size
169
+ self.decoder_conf = decoder_conf
170
+ self.vocab_size = vocab_size
171
+ self.output_type = output_type
172
+ self.input_frame_rate = input_frame_rate
173
+ logging.info(f"input frame rate={self.input_frame_rate}")
174
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
175
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
176
+ self.encoder = encoder
177
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
178
+ self.decoder = decoder
179
+ self.only_mask_loss = only_mask_loss
180
+ self.token_mel_ratio = token_mel_ratio
181
+ self.pre_lookahead_len = pre_lookahead_len
182
+
183
+ def forward(
184
+ self,
185
+ batch: dict,
186
+ device: torch.device,
187
+ ) -> Dict[str, Optional[torch.Tensor]]:
188
+ token = batch['speech_token'].to(device)
189
+ token_len = batch['speech_token_len'].to(device)
190
+ feat = batch['speech_feat'].to(device)
191
+ feat_len = batch['speech_feat_len'].to(device)
192
+ embedding = batch['embedding'].to(device)
193
+
194
+ # NOTE unified training, static_chunk_size > 0 or = 0
195
+ streaming = True if random.random() < 0.5 else False
196
+
197
+ # xvec projection
198
+ embedding = F.normalize(embedding, dim=1)
199
+ embedding = self.spk_embed_affine_layer(embedding)
200
+
201
+ # concat text and prompt_text
202
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
203
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
204
+
205
+ # text encode
206
+ h, h_lengths = self.encoder(token, token_len, streaming=streaming)
207
+ h = self.encoder_proj(h)
208
+
209
+ # get conditions
210
+ conds = torch.zeros(feat.shape, device=token.device)
211
+ for i, j in enumerate(feat_len):
212
+ if random.random() < 0.5:
213
+ continue
214
+ index = random.randint(0, int(0.3 * j))
215
+ conds[i, :index] = feat[i, :index]
216
+ conds = conds.transpose(1, 2)
217
+
218
+ mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
219
+ loss, _ = self.decoder.compute_loss(
220
+ feat.transpose(1, 2).contiguous(),
221
+ mask.unsqueeze(1),
222
+ h.transpose(1, 2).contiguous(),
223
+ embedding,
224
+ cond=conds,
225
+ streaming=streaming,
226
+ )
227
+ return {'loss': loss}
228
+
229
+ @torch.inference_mode()
230
+ def inference(self,
231
+ token,
232
+ token_len,
233
+ prompt_token,
234
+ prompt_token_len,
235
+ prompt_feat,
236
+ prompt_feat_len,
237
+ embedding,
238
+ streaming,
239
+ finalize):
240
+ assert token.shape[0] == 1
241
+ # xvec projection
242
+ embedding = F.normalize(embedding, dim=1)
243
+ embedding = self.spk_embed_affine_layer(embedding)
244
+
245
+ # concat text and prompt_text
246
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
247
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
248
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
249
+
250
+ # text encode
251
+ if finalize is True:
252
+ h, h_lengths = self.encoder(token, token_len, streaming=streaming)
253
+ else:
254
+ token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
255
+ h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
256
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
257
+ h = self.encoder_proj(h)
258
+
259
+ # get conditions
260
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
261
+ conds[:, :mel_len1] = prompt_feat
262
+ conds = conds.transpose(1, 2)
263
+
264
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
265
+ feat, _ = self.decoder(
266
+ mu=h.transpose(1, 2).contiguous(),
267
+ mask=mask.unsqueeze(1),
268
+ spks=embedding,
269
+ cond=conds,
270
+ n_timesteps=10,
271
+ streaming=streaming
272
+ )
273
+ feat = feat[:, :, mel_len1:]
274
+ assert feat.shape[2] == mel_len2
275
+ return feat.float(), None
276
+
277
+
278
+ class CausalMaskedDiffWithDiT(torch.nn.Module):
279
+ def __init__(self,
280
+ input_size: int = 512,
281
+ output_size: int = 80,
282
+ spk_embed_dim: int = 192,
283
+ output_type: str = "mel",
284
+ vocab_size: int = 4096,
285
+ input_frame_rate: int = 50,
286
+ only_mask_loss: bool = True,
287
+ token_mel_ratio: int = 2,
288
+ pre_lookahead_len: int = 3,
289
+ pre_lookahead_layer: torch.nn.Module = None,
290
+ decoder: torch.nn.Module = None,
291
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
292
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
293
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
294
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
295
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
296
+ super().__init__()
297
+ self.input_size = input_size
298
+ self.output_size = output_size
299
+ self.decoder_conf = decoder_conf
300
+ self.vocab_size = vocab_size
301
+ self.output_type = output_type
302
+ self.input_frame_rate = input_frame_rate
303
+ logging.info(f"input frame rate={self.input_frame_rate}")
304
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
305
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
306
+ self.pre_lookahead_len = pre_lookahead_len
307
+ self.pre_lookahead_layer = pre_lookahead_layer
308
+ self.decoder = decoder
309
+ self.only_mask_loss = only_mask_loss
310
+ self.token_mel_ratio = token_mel_ratio
311
+
312
+ def forward(
313
+ self,
314
+ batch: dict,
315
+ device: torch.device,
316
+ ) -> Dict[str, Optional[torch.Tensor]]:
317
+ token = batch['speech_token'].to(device)
318
+ token_len = batch['speech_token_len'].to(device)
319
+ feat = batch['speech_feat'].to(device)
320
+ feat_len = batch['speech_feat_len'].to(device)
321
+ embedding = batch['embedding'].to(device)
322
+
323
+ # NOTE unified training, static_chunk_size > 0 or = 0
324
+ streaming = True if random.random() < 0.5 else False
325
+
326
+ # xvec projection
327
+ embedding = F.normalize(embedding, dim=1)
328
+ embedding = self.spk_embed_affine_layer(embedding)
329
+
330
+ # concat text and prompt_text
331
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
332
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
333
+
334
+ # text encode
335
+ h, h_lengths = self.encoder(token, token_len, streaming=streaming)
336
+ h = self.encoder_proj(h)
337
+
338
+ # get conditions
339
+ conds = torch.zeros(feat.shape, device=token.device)
340
+ for i, j in enumerate(feat_len):
341
+ if random.random() < 0.5:
342
+ continue
343
+ index = random.randint(0, int(0.3 * j))
344
+ conds[i, :index] = feat[i, :index]
345
+ conds = conds.transpose(1, 2)
346
+
347
+ mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
348
+ loss, _ = self.decoder.compute_loss(
349
+ feat.transpose(1, 2).contiguous(),
350
+ mask.unsqueeze(1),
351
+ h.transpose(1, 2).contiguous(),
352
+ embedding,
353
+ cond=conds,
354
+ streaming=streaming,
355
+ )
356
+ return {'loss': loss}
357
+
358
+ @torch.inference_mode()
359
+ def inference(self,
360
+ token,
361
+ token_len,
362
+ prompt_token,
363
+ prompt_token_len,
364
+ prompt_feat,
365
+ prompt_feat_len,
366
+ embedding,
367
+ streaming,
368
+ finalize):
369
+ assert token.shape[0] == 1
370
+ # xvec projection
371
+ embedding = F.normalize(embedding, dim=1)
372
+ embedding = self.spk_embed_affine_layer(embedding)
373
+
374
+ # concat text and prompt_text
375
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
376
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
377
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
378
+
379
+ # text encode
380
+ if finalize is True:
381
+ h = self.pre_lookahead_layer(token)
382
+ else:
383
+ h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
384
+ h = h.repeat_interleave(self.token_mel_ratio, dim=1)
385
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
386
+
387
+ # get conditions
388
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
389
+ conds[:, :mel_len1] = prompt_feat
390
+ conds = conds.transpose(1, 2)
391
+
392
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
393
+ feat, _ = self.decoder(
394
+ mu=h.transpose(1, 2).contiguous(),
395
+ mask=mask.unsqueeze(1),
396
+ spks=embedding,
397
+ cond=conds,
398
+ n_timesteps=10,
399
+ streaming=streaming
400
+ )
401
+ feat = feat[:, :, mel_len1:]
402
+ assert feat.shape[2] == mel_len2
403
+ return feat.float(), None
404
+
405
+
406
+ if __name__ == '__main__':
407
+ torch.backends.cudnn.deterministic = True
408
+ torch.backends.cudnn.benchmark = False
409
+ from hyperpyyaml import load_hyperpyyaml
410
+ with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
411
+ configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
412
+ model = configs['flow']
413
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
414
+ model.to(device)
415
+ model.eval()
416
+ max_len = 10 * model.decoder.estimator.static_chunk_size
417
+ chunk_size = model.decoder.estimator.static_chunk_size
418
+ context_size = model.pre_lookahead_layer.pre_lookahead_len
419
+ token = torch.randint(0, 6561, size=(1, max_len)).to(device)
420
+ token_len = torch.tensor([max_len]).to(device)
421
+ prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
422
+ prompt_token_len = torch.tensor([chunk_size]).to(device)
423
+ prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
424
+ prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
425
+ prompt_embedding = torch.rand(1, 192).to(device)
426
+ pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
427
+ for i in range(0, max_len, chunk_size):
428
+ finalize = True if i + chunk_size + context_size >= max_len else False
429
+ pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
430
+ prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
431
+ pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
432
+ print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+ from cosyvoice.utils.common import set_all_random_seed
19
+
20
+
21
+ class ConditionalCFM(BASECFM):
22
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
23
+ super().__init__(
24
+ n_feats=in_channels,
25
+ cfm_params=cfm_params,
26
+ n_spks=n_spks,
27
+ spk_emb_dim=spk_emb_dim,
28
+ )
29
+ self.t_scheduler = cfm_params.t_scheduler
30
+ self.training_cfg_rate = cfm_params.training_cfg_rate
31
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
32
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
33
+ # Just change the architecture of the estimator here
34
+ self.estimator = estimator
35
+
36
+ @torch.inference_mode()
37
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
38
+ """Forward diffusion
39
+
40
+ Args:
41
+ mu (torch.Tensor): output of encoder
42
+ shape: (batch_size, n_feats, mel_timesteps)
43
+ mask (torch.Tensor): output_mask
44
+ shape: (batch_size, 1, mel_timesteps)
45
+ n_timesteps (int): number of diffusion steps
46
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
47
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
48
+ shape: (batch_size, spk_emb_dim)
49
+ cond (torch.Tensor, optional): conditioning mel features, shape (batch_size, n_feats, mel_timesteps). Defaults to None.
50
+
51
+ Returns:
52
+ sample: generated mel-spectrogram
53
+ shape: (batch_size, n_feats, mel_timesteps)
54
+ """
55
+
56
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
57
+ cache_size = cache.shape[2]
58
+ # fix prompt and overlap part mu and z
59
+ if cache_size != 0:
60
+ z[:, :, :cache_size] = cache[:, :, :, 0]
61
+ mu[:, :, :cache_size] = cache[:, :, :, 1]
62
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
63
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
64
+ cache = torch.stack([z_cache, mu_cache], dim=-1)
65
+
66
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
67
+ if self.t_scheduler == 'cosine':
68
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
69
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
70
+
71
+ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
72
+ """
73
+ Fixed-step Euler solver for ODEs.
74
+ Args:
75
+ x (torch.Tensor): random noise
76
+ t_span (torch.Tensor): n_timesteps interpolated
77
+ shape: (n_timesteps + 1,)
78
+ mu (torch.Tensor): output of encoder
79
+ shape: (batch_size, n_feats, mel_timesteps)
80
+ mask (torch.Tensor): output_mask
81
+ shape: (batch_size, 1, mel_timesteps)
82
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
83
+ shape: (batch_size, spk_emb_dim)
84
+ cond (torch.Tensor, optional): conditioning mel features, shape (batch_size, n_feats, mel_timesteps). Defaults to None.
85
+ """
86
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
87
+ t = t.unsqueeze(dim=0)
88
+
89
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
90
+ # Or in future might add like a return_all_steps flag
91
+ sol = []
92
+
93
+ # Do not use concat, it may change the memory format and cause TRT to infer wrong results!
94
+ # NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
95
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
96
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
97
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
98
+ t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
99
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
100
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
101
+ for step in range(1, len(t_span)):
102
+ # Classifier-Free Guidance inference introduced in VoiceBox
103
+ x_in[:] = x
104
+ mask_in[:] = mask
105
+ mu_in[0] = mu
106
+ t_in[:] = t.unsqueeze(0)
107
+ spks_in[0] = spks
108
+ cond_in[0] = cond
109
+ dphi_dt = self.forward_estimator(
110
+ x_in, mask_in,
111
+ mu_in, t_in,
112
+ spks_in,
113
+ cond_in,
114
+ streaming
115
+ )
116
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
117
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
118
+ x = x + dt * dphi_dt
119
+ t = t + dt
120
+ sol.append(x)
121
+ if step < len(t_span) - 1:
122
+ dt = t_span[step + 1] - t
123
+
124
+ return sol[-1].float()
125
+
126
+ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
127
+ if isinstance(self.estimator, torch.nn.Module):
128
+ return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
129
+ else:
130
+ [estimator, stream], trt_engine = self.estimator.acquire_estimator()
131
+ # NOTE need to synchronize when switching stream
132
+ torch.cuda.current_stream().synchronize()
133
+ with stream:
134
+ estimator.set_input_shape('x', (2, 80, x.size(2)))
135
+ estimator.set_input_shape('mask', (2, 1, x.size(2)))
136
+ estimator.set_input_shape('mu', (2, 80, x.size(2)))
137
+ estimator.set_input_shape('t', (2,))
138
+ estimator.set_input_shape('spks', (2, 80))
139
+ estimator.set_input_shape('cond', (2, 80, x.size(2)))
140
+ data_ptrs = [x.contiguous().data_ptr(),
141
+ mask.contiguous().data_ptr(),
142
+ mu.contiguous().data_ptr(),
143
+ t.contiguous().data_ptr(),
144
+ spks.contiguous().data_ptr(),
145
+ cond.contiguous().data_ptr(),
146
+ x.data_ptr()]
147
+ for i, j in enumerate(data_ptrs):
148
+ estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
149
+ # run trt engine
150
+ assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
151
+ torch.cuda.current_stream().synchronize()
152
+ self.estimator.release_estimator(estimator, stream)
153
+ return x
154
+
155
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
156
+ """Computes diffusion loss
157
+
158
+ Args:
159
+ x1 (torch.Tensor): Target
160
+ shape: (batch_size, n_feats, mel_timesteps)
161
+ mask (torch.Tensor): target mask
162
+ shape: (batch_size, 1, mel_timesteps)
163
+ mu (torch.Tensor): output of encoder
164
+ shape: (batch_size, n_feats, mel_timesteps)
165
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
166
+ shape: (batch_size, spk_emb_dim)
167
+
168
+ Returns:
169
+ loss: conditional flow matching loss
170
+ y: conditional flow
171
+ shape: (batch_size, n_feats, mel_timesteps)
172
+ """
173
+ b, _, t = mu.shape
174
+
175
+ # random timestep
176
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
177
+ if self.t_scheduler == 'cosine':
178
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
179
+ # sample noise p(x_0)
180
+ z = torch.randn_like(x1)
181
+
182
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
183
+ u = x1 - (1 - self.sigma_min) * z
184
+
185
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
186
+ if self.training_cfg_rate > 0:
187
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
188
+ mu = mu * cfg_mask.view(-1, 1, 1)
189
+ spks = spks * cfg_mask.view(-1, 1)
190
+ cond = cond * cfg_mask.view(-1, 1, 1)
191
+
192
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
193
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
194
+ return loss, y
195
+
196
+
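As a reading aid for compute_loss above: with noise z ~ N(0, I) and a per-sample timestep t (cosine-warped when t_scheduler == 'cosine'), the interpolant and regression target are

    y_t = \bigl(1 - (1 - \sigma_{\min})\, t\bigr)\, z + t\, x_1, \qquad u = x_1 - (1 - \sigma_{\min})\, z

and the loss is the masked MSE

    \mathcal{L} = \frac{\lVert \bigl(v_\theta(y_t, t, \mu, \mathrm{spks}, \mathrm{cond}) - u\bigr) \odot m \rVert_2^2}{n_{\mathrm{feats}} \sum m}

where m is the frame mask; mu, spks and cond are zeroed jointly per sample with probability training_cfg_rate, so the same estimator supports the classifier-free guidance applied in solve_euler at inference.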
197
+ class CausalConditionalCFM(ConditionalCFM):
198
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
199
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
200
+ set_all_random_seed(0)
201
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
202
+
203
+ @torch.inference_mode()
204
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
205
+ """Forward diffusion
206
+
207
+ Args:
208
+ mu (torch.Tensor): output of encoder
209
+ shape: (batch_size, n_feats, mel_timesteps)
210
+ mask (torch.Tensor): output_mask
211
+ shape: (batch_size, 1, mel_timesteps)
212
+ n_timesteps (int): number of diffusion steps
213
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
214
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
215
+ shape: (batch_size, spk_emb_dim)
216
+ cond (torch.Tensor, optional): conditioning mel features, shape (batch_size, n_feats, mel_timesteps). Defaults to None.
217
+
218
+ Returns:
219
+ sample: generated mel-spectrogram
220
+ shape: (batch_size, n_feats, mel_timesteps)
221
+ """
222
+
223
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
224
+ # fix prompt and overlap part mu and z
225
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
226
+ if self.t_scheduler == 'cosine':
227
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
228
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
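Both CFM variants warp the uniform integration grid when t_scheduler == 'cosine', which makes the Euler steps small near t = 0 (the noise end) and larger near t = 1. A tiny standalone sketch of the warp used above (values approximate):

import torch

n_timesteps = 10
t_span = torch.linspace(0, 1, n_timesteps + 1)
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)   # same warp as in forward()
print(t_span[:3])   # roughly [0.0000, 0.0123, 0.0489]: early steps are much finer than 1/10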
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
+ class InterpolateRegulator(nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ sampling_ratios: Tuple,
26
+ out_channels: int = None,
27
+ groups: int = 1,
28
+ ):
29
+ super().__init__()
30
+ self.sampling_ratios = sampling_ratios
31
+ out_channels = out_channels or channels
32
+ model = nn.ModuleList([])
33
+ if len(sampling_ratios) > 0:
34
+ for _ in sampling_ratios:
35
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
36
+ norm = nn.GroupNorm(groups, channels)
37
+ act = nn.Mish()
38
+ model.extend([module, norm, act])
39
+ model.append(
40
+ nn.Conv1d(channels, out_channels, 1, 1)
41
+ )
42
+ self.model = nn.Sequential(*model)
43
+
44
+ def forward(self, x, ylens=None):
45
+ # x in (B, T, D)
46
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48
+ out = self.model(x).transpose(1, 2).contiguous()
49
+ olens = ylens
50
+ return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53
+ # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
54
+ # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
55
+ # x in (B, T, D)
56
+ if x2.shape[1] > 40:
57
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
58
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
59
+ mode='linear')
60
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
61
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
62
+ else:
63
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
64
+ if x1.shape[1] != 0:
65
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
66
+ x = torch.concat([x1, x2], dim=2)
67
+ else:
68
+ x = x2
69
+ out = self.model(x).transpose(1, 2).contiguous()
70
+ return out, mel_len1 + mel_len2
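The head/mid/tail split in InterpolateRegulator.inference pins the 20-token overlap regions (token_overlap_len in cosyvoice/cli/model.py, per the comment above) to a fixed mel length so chunk boundaries line up across streaming calls; 20 tokens at 50 Hz map to int(20 / 50 * 22050 / 256) = 34 mel frames, which appears to match the 34-frame cache slice kept in ConditionalCFM.forward. A standalone sketch of the same interpolation arithmetic using only torch (sizes illustrative):

import torch
import torch.nn.functional as F

input_frame_rate = 50
overlap_mel = int(20 / input_frame_rate * 22050 / 256)            # 34 mel frames per 20-token overlap

x2 = torch.randn(1, 60, 80)                                       # 60 encoded tokens, 80-dim features
mel_len2 = int(60 / input_frame_rate * 22050 / 256)               # 103 target mel frames
head = F.interpolate(x2[:, :20].transpose(1, 2), size=overlap_mel, mode='linear')
mid = F.interpolate(x2[:, 20:-20].transpose(1, 2), size=mel_len2 - 2 * overlap_mel, mode='linear')
tail = F.interpolate(x2[:, -20:].transpose(1, 2), size=overlap_mel, mode='linear')
print(torch.cat([head, mid, tail], dim=2).shape)                  # torch.Size([1, 80, 103])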
cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,230 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ try:
5
+ from torch.nn.utils.parametrizations import weight_norm, spectral_norm
6
+ except ImportError:
7
+ from torch.nn.utils import weight_norm, spectral_norm
8
+ from typing import List, Optional, Tuple
9
+ from einops import rearrange
10
+ from torchaudio.transforms import Spectrogram
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
+ class MultipleDiscriminator(nn.Module):
16
+ def __init__(
17
+ self, mpd: nn.Module, mrd: nn.Module
18
+ ):
19
+ super().__init__()
20
+ self.mpd = mpd
21
+ self.mrd = mrd
22
+
23
+ def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
24
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
25
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
26
+ y_d_rs += this_y_d_rs
27
+ y_d_gs += this_y_d_gs
28
+ fmap_rs += this_fmap_rs
29
+ fmap_gs += this_fmap_gs
30
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
31
+ y_d_rs += this_y_d_rs
32
+ y_d_gs += this_y_d_gs
33
+ fmap_rs += this_fmap_rs
34
+ fmap_gs += this_fmap_gs
35
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
36
+
37
+
38
+ class MultiResolutionDiscriminator(nn.Module):
39
+ def __init__(
40
+ self,
41
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
42
+ num_embeddings: Optional[int] = None,
43
+ ):
44
+ """
45
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
46
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
47
+
48
+ Args:
49
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
50
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
51
+ Defaults to None.
52
+ """
53
+
54
+ super().__init__()
55
+ self.discriminators = nn.ModuleList(
56
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
57
+ )
58
+
59
+ def forward(
60
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
61
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
62
+ y_d_rs = []
63
+ y_d_gs = []
64
+ fmap_rs = []
65
+ fmap_gs = []
66
+
67
+ for d in self.discriminators:
68
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
69
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
70
+ y_d_rs.append(y_d_r)
71
+ fmap_rs.append(fmap_r)
72
+ y_d_gs.append(y_d_g)
73
+ fmap_gs.append(fmap_g)
74
+
75
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
76
+
77
+
78
+ class DiscriminatorR(nn.Module):
79
+ def __init__(
80
+ self,
81
+ window_length: int,
82
+ num_embeddings: Optional[int] = None,
83
+ channels: int = 32,
84
+ hop_factor: float = 0.25,
85
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
86
+ ):
87
+ super().__init__()
88
+ self.window_length = window_length
89
+ self.hop_factor = hop_factor
90
+ self.spec_fn = Spectrogram(
91
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
92
+ )
93
+ n_fft = window_length // 2 + 1
94
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
95
+ self.bands = bands
96
+ convs = lambda: nn.ModuleList(
97
+ [
98
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
99
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
100
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
101
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
102
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
103
+ ]
104
+ )
105
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
106
+
107
+ if num_embeddings is not None:
108
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
109
+ torch.nn.init.zeros_(self.emb.weight)
110
+
111
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
112
+
113
+ def spectrogram(self, x):
114
+ # Remove DC offset
115
+ x = x - x.mean(dim=-1, keepdims=True)
116
+ # Peak normalize the volume of input audio
117
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
118
+ x = self.spec_fn(x)
119
+ x = torch.view_as_real(x)
120
+ x = rearrange(x, "b f t c -> b c t f")
121
+ # Split into bands
122
+ x_bands = [x[..., b[0]: b[1]] for b in self.bands]
123
+ return x_bands
124
+
125
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
126
+ x_bands = self.spectrogram(x)
127
+ fmap = []
128
+ x = []
129
+ for band, stack in zip(x_bands, self.band_convs):
130
+ for i, layer in enumerate(stack):
131
+ band = layer(band)
132
+ band = torch.nn.functional.leaky_relu(band, 0.1)
133
+ if i > 0:
134
+ fmap.append(band)
135
+ x.append(band)
136
+ x = torch.cat(x, dim=-1)
137
+ if cond_embedding_id is not None:
138
+ emb = self.emb(cond_embedding_id)
139
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
140
+ else:
141
+ h = 0
142
+ x = self.conv_post(x)
143
+ fmap.append(x)
144
+ x += h
145
+
146
+ return x, fmap
147
+
148
+
149
+ class MultiResSpecDiscriminator(torch.nn.Module):
150
+
151
+ def __init__(self,
152
+ fft_sizes=[1024, 2048, 512],
153
+ hop_sizes=[120, 240, 50],
154
+ win_lengths=[600, 1200, 240],
155
+ window="hann_window"):
156
+
157
+ super(MultiResSpecDiscriminator, self).__init__()
158
+ self.discriminators = nn.ModuleList([
159
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
160
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
161
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
162
+
163
+ def forward(self, y, y_hat):
164
+ y_d_rs = []
165
+ y_d_gs = []
166
+ fmap_rs = []
167
+ fmap_gs = []
168
+ for _, d in enumerate(self.discriminators):
169
+ y_d_r, fmap_r = d(y)
170
+ y_d_g, fmap_g = d(y_hat)
171
+ y_d_rs.append(y_d_r)
172
+ fmap_rs.append(fmap_r)
173
+ y_d_gs.append(y_d_g)
174
+ fmap_gs.append(fmap_g)
175
+
176
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
177
+
178
+
179
+ def stft(x, fft_size, hop_size, win_length, window):
180
+ """Perform STFT and convert to magnitude spectrogram.
181
+ Args:
182
+ x (Tensor): Input signal tensor (B, T).
183
+ fft_size (int): FFT size.
184
+ hop_size (int): Hop size.
185
+ win_length (int): Window length.
186
+ window (str): Window function type.
187
+ Returns:
188
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
189
+ """
190
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
191
+
192
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
193
+ return torch.abs(x_stft).transpose(2, 1)
194
+
195
+
196
+ class SpecDiscriminator(nn.Module):
197
+ """docstring for Discriminator."""
198
+
199
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
200
+ super(SpecDiscriminator, self).__init__()
201
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
202
+ self.fft_size = fft_size
203
+ self.shift_size = shift_size
204
+ self.win_length = win_length
205
+ self.window = getattr(torch, window)(win_length)
206
+ self.discriminators = nn.ModuleList([
207
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
208
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
209
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
210
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
211
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
212
+ ])
213
+
214
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
215
+
216
+ def forward(self, y):
217
+
218
+ fmap = []
219
+ y = y.squeeze(1)
220
+ y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
221
+ y = y.unsqueeze(1)
222
+ for _, d in enumerate(self.discriminators):
223
+ y = d(y)
224
+ y = F.leaky_relu(y, LRELU_SLOPE)
225
+ fmap.append(y)
226
+
227
+ y = self.out(y)
228
+ fmap.append(y)
229
+
230
+ return torch.flatten(y, 1, -1), fmap
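A minimal shape check for the multi-resolution spectral discriminator above (batch size and sample count are arbitrary; the import path follows this upload's layout):

import torch
from cosyvoice.hifigan.discriminator import MultiResSpecDiscriminator

disc = MultiResSpecDiscriminator()
y = torch.randn(2, 16000)            # real waveforms
y_hat = torch.randn(2, 16000)        # generated waveforms
y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)
print(len(y_d_rs), len(fmap_rs))     # 3 resolutions each for real and generated inputs
print(y_d_rs[0].shape)               # flattened logits for the first resolution, shape (2, N)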
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ try:
17
+ from torch.nn.utils.parametrizations import weight_norm
18
+ except ImportError:
19
+ from torch.nn.utils import weight_norm
20
+ from cosyvoice.transformer.convolution import CausalConv1d
21
+
22
+
23
+ class ConvRNNF0Predictor(nn.Module):
24
+ def __init__(self,
25
+ num_class: int = 1,
26
+ in_channels: int = 80,
27
+ cond_channels: int = 512
28
+ ):
29
+ super().__init__()
30
+
31
+ self.num_class = num_class
32
+ self.condnet = nn.Sequential(
33
+ weight_norm(
34
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ weight_norm(
50
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
51
+ ),
52
+ nn.ELU(),
53
+ )
54
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ x = self.condnet(x)
58
+ x = x.transpose(1, 2)
59
+ return torch.abs(self.classifier(x).squeeze(-1))
60
+
61
+
62
+ class CausalConvRNNF0Predictor(nn.Module):
63
+ def __init__(self,
64
+ num_class: int = 1,
65
+ in_channels: int = 80,
66
+ cond_channels: int = 512
67
+ ):
68
+ super().__init__()
69
+
70
+ self.num_class = num_class
71
+ self.condnet = nn.Sequential(
72
+ weight_norm(
73
+ CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
74
+ ),
75
+ nn.ELU(),
76
+ weight_norm(
77
+ CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
78
+ ),
79
+ nn.ELU(),
80
+ weight_norm(
81
+ CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
82
+ ),
83
+ nn.ELU(),
84
+ weight_norm(
85
+ CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
86
+ ),
87
+ nn.ELU(),
88
+ weight_norm(
89
+ CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
90
+ ),
91
+ nn.ELU(),
92
+ )
93
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
94
+
95
+ def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
96
+ if finalize is True:
97
+ x = self.condnet[0](x)
98
+ else:
99
+ x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
100
+ for i in range(1, len(self.condnet)):
101
+ x = self.condnet[i](x)
102
+ x = x.transpose(1, 2)
103
+ return torch.abs(self.classifier(x).squeeze(-1))
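And a corresponding shape check for the non-causal F0 predictor above; the input is an 80-bin mel of arbitrary length:

import torch
from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor()     # defaults: in_channels=80, cond_channels=512
mel = torch.randn(1, 80, 200)        # (batch, mel_bins, frames)
f0 = predictor(mel)
print(f0.shape)                      # torch.Size([1, 200]): one non-negative F0 value per frame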
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,746 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ try:
27
+ from torch.nn.utils.parametrizations import weight_norm
28
+ except ImportError:
29
+ from torch.nn.utils import weight_norm
30
+ from torch.distributions.uniform import Uniform
31
+ from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
32
+ from cosyvoice.transformer.activation import Snake
33
+ from cosyvoice.utils.common import get_padding
34
+ from cosyvoice.utils.common import init_weights
35
+
36
+
37
+ """hifigan based generator implementation.
38
+
39
+ This code is modified from https://github.com/jik876/hifi-gan
40
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
41
+ https://github.com/NVIDIA/BigVGAN
42
+
43
+ """
44
+
45
+
46
+ class ResBlock(torch.nn.Module):
47
+ """Residual block module in HiFiGAN/BigVGAN."""
48
+ def __init__(
49
+ self,
50
+ channels: int = 512,
51
+ kernel_size: int = 3,
52
+ dilations: List[int] = [1, 3, 5],
53
+ causal: bool = False,
54
+ ):
55
+ super(ResBlock, self).__init__()
56
+ self.causal = causal
57
+ self.convs1 = nn.ModuleList()
58
+ self.convs2 = nn.ModuleList()
59
+
60
+ for dilation in dilations:
61
+ self.convs1.append(
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=dilation,
69
+ padding=get_padding(kernel_size, dilation)) if causal is False else
70
+ CausalConv1d(
71
+ channels,
72
+ channels,
73
+ kernel_size,
74
+ 1,
75
+ dilation=dilation,
76
+ causal_type='left'
77
+ )
78
+ )
79
+ )
80
+ self.convs2.append(
81
+ weight_norm(
82
+ Conv1d(
83
+ channels,
84
+ channels,
85
+ kernel_size,
86
+ 1,
87
+ dilation=1,
88
+ padding=get_padding(kernel_size, 1)) if causal is False else
89
+ CausalConv1d(
90
+ channels,
91
+ channels,
92
+ kernel_size,
93
+ 1,
94
+ dilation=1,
95
+ causal_type='left'
96
+ )
97
+ )
98
+ )
99
+ self.convs1.apply(init_weights)
100
+ self.convs2.apply(init_weights)
101
+ self.activations1 = nn.ModuleList([
102
+ Snake(channels, alpha_logscale=False)
103
+ for _ in range(len(self.convs1))
104
+ ])
105
+ self.activations2 = nn.ModuleList([
106
+ Snake(channels, alpha_logscale=False)
107
+ for _ in range(len(self.convs2))
108
+ ])
109
+
110
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
111
+ for idx in range(len(self.convs1)):
112
+ xt = self.activations1[idx](x)
113
+ xt = self.convs1[idx](xt)
114
+ xt = self.activations2[idx](xt)
115
+ xt = self.convs2[idx](xt)
116
+ x = xt + x
117
+ return x
118
+
119
+ def remove_weight_norm(self):
120
+ for idx in range(len(self.convs1)):
121
+ remove_weight_norm(self.convs1[idx])
122
+ remove_weight_norm(self.convs2[idx])
123
+
124
+
125
+ class SineGen(torch.nn.Module):
126
+ """ Definition of sine generator
127
+ SineGen(samp_rate, harmonic_num = 0,
128
+ sine_amp = 0.1, noise_std = 0.003,
129
+ voiced_threshold = 0,
130
+ flag_for_pulse=False)
131
+ samp_rate: sampling rate in Hz
132
+ harmonic_num: number of harmonic overtones (default 0)
133
+ sine_amp: amplitude of sine-waveform (default 0.1)
134
+ noise_std: std of Gaussian noise (default 0.003)
135
+ voiced_threshold: F0 threshold for U/V classification (default 0)
136
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
137
+ Note: when flag_for_pulse is True, the first time step of a voiced
138
+ segment is always sin(np.pi) or cos(0)
139
+ """
140
+
141
+ def __init__(self, samp_rate, harmonic_num=0,
142
+ sine_amp=0.1, noise_std=0.003,
143
+ voiced_threshold=0):
144
+ super(SineGen, self).__init__()
145
+ self.sine_amp = sine_amp
146
+ self.noise_std = noise_std
147
+ self.harmonic_num = harmonic_num
148
+ self.sampling_rate = samp_rate
149
+ self.voiced_threshold = voiced_threshold
150
+
151
+ def _f02uv(self, f0):
152
+ # generate uv signal
153
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
154
+ return uv
155
+
156
+ @torch.no_grad()
157
+ def forward(self, f0):
158
+ """ sine_tensor, uv = forward(f0)
159
+ input F0: tensor(batchsize=1, dim=1, length)
160
+ f0 for unvoiced steps should be 0
161
+ output sine_tensor: tensor(batchsize=1, length, dim)
162
+ output uv: tensor(batchsize=1, length, 1)
163
+ """
164
+ f0 = f0.transpose(1, 2)
165
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
166
+ for i in range(self.harmonic_num + 1):
167
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
168
+
169
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
170
+ u_dist = Uniform(low=-np.pi, high=np.pi)
171
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
172
+ phase_vec[:, 0, :] = 0
173
+
174
+ # generate sine waveforms
175
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
176
+
177
+ # generate uv signal
178
+ uv = self._f02uv(f0)
179
+
180
+ # noise: for unvoiced should be similar to sine_amp
181
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
182
+ # for voiced regions the std is self.noise_std
183
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
184
+ noise = noise_amp * torch.randn_like(sine_waves)
185
+
186
+ # first: set the unvoiced part to 0 by uv
187
+ # then: additive noise
188
+ sine_waves = sine_waves * uv + noise
189
+ return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
190
+
191
+
192
+ class SineGen2(torch.nn.Module):
193
+ """ Definition of sine generator
194
+ SineGen(samp_rate, harmonic_num = 0,
195
+ sine_amp = 0.1, noise_std = 0.003,
196
+ voiced_threshold = 0,
197
+ flag_for_pulse=False)
198
+ samp_rate: sampling rate in Hz
199
+ harmonic_num: number of harmonic overtones (default 0)
200
+ sine_amp: amplitude of sine waveform (default 0.1)
201
+ noise_std: std of Gaussian noise (default 0.003)
202
+ voiced_threshold: F0 threshold for U/V classification (default 0)
203
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
204
+ Note: when flag_for_pulse is True, the first time step of a voiced
205
+ segment is always sin(np.pi) or cos(0)
206
+ """
207
+
208
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
209
+ sine_amp=0.1, noise_std=0.003,
210
+ voiced_threshold=0,
211
+ flag_for_pulse=False,
212
+ causal=False):
213
+ super(SineGen2, self).__init__()
214
+ self.sine_amp = sine_amp
215
+ self.noise_std = noise_std
216
+ self.harmonic_num = harmonic_num
217
+ self.dim = self.harmonic_num + 1
218
+ self.sampling_rate = samp_rate
219
+ self.voiced_threshold = voiced_threshold
220
+ self.flag_for_pulse = flag_for_pulse
221
+ self.upsample_scale = upsample_scale
222
+ self.causal = causal
223
+ if causal is True:
224
+ self.rand_ini = torch.rand(1, 9)
225
+ self.rand_ini[:, 0] = 0
226
+ self.sine_waves = torch.rand(1, 300 * 24000, 9)
227
+
228
+ def _f02uv(self, f0):
229
+ # generate uv signal
230
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
231
+ return uv
232
+
233
+ def _f02sine(self, f0_values):
234
+ """ f0_values: (batchsize, length, dim)
235
+ where dim indicates fundamental tone and overtones
236
+ """
237
+ # convert to F0 in rad. The integer part n can be ignored
238
+ # because 2 * np.pi * n doesn't affect phase
239
+ rad_values = (f0_values / self.sampling_rate) % 1
240
+
241
+ # initial phase noise (no noise for fundamental component)
242
+ if self.training is False and self.causal is True:
243
+ rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
244
+ else:
245
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
246
+ rand_ini[:, 0] = 0
247
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
248
+
249
+ # instantaneous phase sine[t] = sin(2 * pi * \sum_{i=1}^{t} rad)
250
+ if not self.flag_for_pulse:
251
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
252
+ scale_factor=1 / self.upsample_scale,
253
+ mode="linear").transpose(1, 2)
254
+
255
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
256
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
257
+ scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
258
+ sines = torch.sin(phase)
259
+ else:
260
+ # If necessary, make sure that the first time step of every
261
+ # voiced segment is sin(pi) or cos(0)
262
+ # This is used for pulse-train generation
263
+
264
+ # identify the last time step in unvoiced segments
265
+ uv = self._f02uv(f0_values)
266
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
267
+ uv_1[:, -1, :] = 1
268
+ u_loc = (uv < 1) * (uv_1 > 0)
269
+
270
+ # get the instantaneous phase
271
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
272
+ # different batch needs to be processed differently
273
+ for idx in range(f0_values.shape[0]):
274
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
275
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
276
+ # stores the accumulation of i.phase within
277
+ # each voiced segment
278
+ tmp_cumsum[idx, :, :] = 0
279
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
280
+
281
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
282
+ # within the previous voiced segment.
283
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
284
+
285
+ # get the sines
286
+ sines = torch.cos(i_phase * 2 * np.pi)
287
+ return sines
288
+
289
+ def forward(self, f0):
290
+ """ sine_tensor, uv = forward(f0)
291
+ input F0: tensor(batchsize=1, length, dim=1)
292
+ f0 for unvoiced steps should be 0
293
+ output sine_tensor: tensor(batchsize=1, length, dim)
294
+ output uv: tensor(batchsize=1, length, 1)
295
+ """
296
+ # fundamental component
297
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
298
+
299
+ # generate sine waveforms
300
+ sine_waves = self._f02sine(fn) * self.sine_amp
301
+
302
+ # generate uv signal
303
+ uv = self._f02uv(f0)
304
+
305
+ # noise: for unvoiced should be similar to sine_amp
306
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
307
+ # for voiced regions the std is self.noise_std
308
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
309
+ if self.training is False and self.causal is True:
310
+ noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
311
+ else:
312
+ noise = noise_amp * torch.randn_like(sine_waves)
313
+
314
+ # first: set the unvoiced part to 0 by uv
315
+ # then: additive noise
316
+ sine_waves = sine_waves * uv + noise
317
+ return sine_waves, uv, noise
318
+
319
+
320
+ class SourceModuleHnNSF(torch.nn.Module):
321
+ """ SourceModule for hn-nsf
322
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
323
+ add_noise_std=0.003, voiced_threshod=0)
324
+ sampling_rate: sampling_rate in Hz
325
+ harmonic_num: number of harmonic above F0 (default: 0)
326
+ sine_amp: amplitude of sine source signal (default: 0.1)
327
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
328
+ note that amplitude of noise in unvoiced is decided
329
+ by sine_amp
330
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
331
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
332
+ F0_sampled (batchsize, length, 1)
333
+ Sine_source (batchsize, length, 1)
334
+ noise_source (batchsize, length, 1)
335
+ uv (batchsize, length, 1)
336
+ """
337
+
338
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
339
+ add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
340
+ super(SourceModuleHnNSF, self).__init__()
341
+
342
+ self.sine_amp = sine_amp
343
+ self.noise_std = add_noise_std
344
+
345
+ # to produce sine waveforms
346
+ if sinegen_type == '1':
347
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
348
+ else:
349
+ self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
350
+
351
+ # to merge source harmonics into a single excitation
352
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
353
+ self.l_tanh = torch.nn.Tanh()
354
+ self.causal = causal
355
+ if causal is True:
356
+ self.uv = torch.rand(1, 300 * 24000, 1)
357
+
358
+ def forward(self, x):
359
+ """
360
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
361
+ F0_sampled (batchsize, length, 1)
362
+ Sine_source (batchsize, length, 1)
363
+ noise_source (batchsize, length, 1)
364
+ """
365
+ # source for harmonic branch
366
+ with torch.no_grad():
367
+ sine_wavs, uv, _ = self.l_sin_gen(x)
368
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
369
+
370
+ # source for noise branch, in the same shape as uv
371
+ if self.training is False and self.causal is True:
372
+ noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
373
+ else:
374
+ noise = torch.randn_like(uv) * self.sine_amp / 3
375
+ return sine_merge, noise, uv
376
+
377
+
378
+ class HiFTGenerator(nn.Module):
379
+ """
380
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
381
+ https://arxiv.org/abs/2309.09493
382
+ """
383
+ def __init__(
384
+ self,
385
+ in_channels: int = 80,
386
+ base_channels: int = 512,
387
+ nb_harmonics: int = 8,
388
+ sampling_rate: int = 22050,
389
+ nsf_alpha: float = 0.1,
390
+ nsf_sigma: float = 0.003,
391
+ nsf_voiced_threshold: float = 10,
392
+ upsample_rates: List[int] = [8, 8],
393
+ upsample_kernel_sizes: List[int] = [16, 16],
394
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
395
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
396
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
397
+ source_resblock_kernel_sizes: List[int] = [7, 11],
398
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
399
+ lrelu_slope: float = 0.1,
400
+ audio_limit: float = 0.99,
401
+ f0_predictor: torch.nn.Module = None,
402
+ ):
403
+ super(HiFTGenerator, self).__init__()
404
+
405
+ self.out_channels = 1
406
+ self.nb_harmonics = nb_harmonics
407
+ self.sampling_rate = sampling_rate
408
+ self.istft_params = istft_params
409
+ self.lrelu_slope = lrelu_slope
410
+ self.audio_limit = audio_limit
411
+
412
+ self.num_kernels = len(resblock_kernel_sizes)
413
+ self.num_upsamples = len(upsample_rates)
414
+ # NOTE in CosyVoice2, we use the original SineGen implementation
415
+ self.m_source = SourceModuleHnNSF(
416
+ sampling_rate=sampling_rate,
417
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
418
+ harmonic_num=nb_harmonics,
419
+ sine_amp=nsf_alpha,
420
+ add_noise_std=nsf_sigma,
421
+ voiced_threshod=nsf_voiced_threshold,
422
+ sinegen_type='1' if self.sampling_rate == 22050 else '2',
423
+ causal=False)
424
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
425
+
426
+ self.conv_pre = weight_norm(
427
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
428
+ )
429
+
430
+ # Up
431
+ self.ups = nn.ModuleList()
432
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
433
+ self.ups.append(
434
+ weight_norm(
435
+ ConvTranspose1d(
436
+ base_channels // (2**i),
437
+ base_channels // (2**(i + 1)),
438
+ k,
439
+ u,
440
+ padding=(k - u) // 2,
441
+ )
442
+ )
443
+ )
444
+
445
+ # Down
446
+ self.source_downs = nn.ModuleList()
447
+ self.source_resblocks = nn.ModuleList()
448
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
449
+ downsample_cum_rates = np.cumprod(downsample_rates)
450
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
451
+ if u == 1:
452
+ self.source_downs.append(
453
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
454
+ )
455
+ else:
456
+ self.source_downs.append(
457
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
458
+ )
459
+
460
+ self.source_resblocks.append(
461
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
462
+ )
463
+
464
+ self.resblocks = nn.ModuleList()
465
+ for i in range(len(self.ups)):
466
+ ch = base_channels // (2**(i + 1))
467
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
468
+ self.resblocks.append(ResBlock(ch, k, d))
469
+
470
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
471
+ self.ups.apply(init_weights)
472
+ self.conv_post.apply(init_weights)
473
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
474
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
475
+ self.f0_predictor = f0_predictor
476
+
477
+ def remove_weight_norm(self):
478
+ print('Removing weight norm...')
479
+ for l in self.ups:
480
+ remove_weight_norm(l)
481
+ for l in self.resblocks:
482
+ l.remove_weight_norm()
483
+ remove_weight_norm(self.conv_pre)
484
+ remove_weight_norm(self.conv_post)
485
+ self.m_source.remove_weight_norm()
486
+ for l in self.source_downs:
487
+ remove_weight_norm(l)
488
+ for l in self.source_resblocks:
489
+ l.remove_weight_norm()
490
+
491
+ def _stft(self, x):
492
+ spec = torch.stft(
493
+ x,
494
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
495
+ return_complex=True)
496
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
497
+ return spec[..., 0], spec[..., 1]
498
+
499
+ def _istft(self, magnitude, phase):
500
+ magnitude = torch.clip(magnitude, max=1e2)
501
+ real = magnitude * torch.cos(phase)
502
+ img = magnitude * torch.sin(phase)
503
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
504
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
505
+ return inverse_transform
506
+
507
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
508
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
509
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
510
+
511
+ x = self.conv_pre(x)
512
+ for i in range(self.num_upsamples):
513
+ x = F.leaky_relu(x, self.lrelu_slope)
514
+ x = self.ups[i](x)
515
+
516
+ if i == self.num_upsamples - 1:
517
+ x = self.reflection_pad(x)
518
+
519
+ # fusion
520
+ si = self.source_downs[i](s_stft)
521
+ si = self.source_resblocks[i](si)
522
+ x = x + si
523
+
524
+ xs = None
525
+ for j in range(self.num_kernels):
526
+ if xs is None:
527
+ xs = self.resblocks[i * self.num_kernels + j](x)
528
+ else:
529
+ xs += self.resblocks[i * self.num_kernels + j](x)
530
+ x = xs / self.num_kernels
531
+
532
+ x = F.leaky_relu(x)
533
+ x = self.conv_post(x)
534
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
535
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, the sin here is redundant
536
+
537
+ x = self._istft(magnitude, phase)
538
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
539
+ return x
540
+
541
+ def forward(
542
+ self,
543
+ batch: dict,
544
+ device: torch.device,
545
+ ) -> Dict[str, Optional[torch.Tensor]]:
546
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
547
+ # mel->f0
548
+ f0 = self.f0_predictor(speech_feat)
549
+ # f0->source
550
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
551
+ s, _, _ = self.m_source(s)
552
+ s = s.transpose(1, 2)
553
+ # mel+source->speech
554
+ generated_speech = self.decode(x=speech_feat, s=s)
555
+ return generated_speech, f0
556
+
557
+ @torch.inference_mode()
558
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
559
+ # mel->f0
560
+ f0 = self.f0_predictor(speech_feat)
561
+ # f0->source
562
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
563
+ s, _, _ = self.m_source(s)
564
+ s = s.transpose(1, 2)
565
+ # use cache_source to avoid glitch
566
+ if cache_source.shape[2] != 0:
567
+ s[:, :, :cache_source.shape[2]] = cache_source
568
+ generated_speech = self.decode(x=speech_feat, s=s)
569
+ return generated_speech, s
570
+
571
+
572
+ class CausalHiFTGenerator(HiFTGenerator):
573
+ """
574
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
575
+ https://arxiv.org/abs/2309.09493
576
+ """
577
+ def __init__(
578
+ self,
579
+ in_channels: int = 80,
580
+ base_channels: int = 512,
581
+ nb_harmonics: int = 8,
582
+ sampling_rate: int = 22050,
583
+ nsf_alpha: float = 0.1,
584
+ nsf_sigma: float = 0.003,
585
+ nsf_voiced_threshold: float = 10,
586
+ upsample_rates: List[int] = [8, 8],
587
+ upsample_kernel_sizes: List[int] = [16, 16],
588
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
589
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
590
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
591
+ source_resblock_kernel_sizes: List[int] = [7, 11],
592
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
593
+ lrelu_slope: float = 0.1,
594
+ audio_limit: float = 0.99,
595
+ conv_pre_look_right: int = 4,
596
+ f0_predictor: torch.nn.Module = None,
597
+ ):
598
+ torch.nn.Module.__init__(self)
599
+
600
+ self.out_channels = 1
601
+ self.nb_harmonics = nb_harmonics
602
+ self.sampling_rate = sampling_rate
603
+ self.istft_params = istft_params
604
+ self.lrelu_slope = lrelu_slope
605
+ self.audio_limit = audio_limit
606
+
607
+ self.num_kernels = len(resblock_kernel_sizes)
608
+ self.num_upsamples = len(upsample_rates)
609
+ self.m_source = SourceModuleHnNSF(
610
+ sampling_rate=sampling_rate,
611
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
612
+ harmonic_num=nb_harmonics,
613
+ sine_amp=nsf_alpha,
614
+ add_noise_std=nsf_sigma,
615
+ voiced_threshod=nsf_voiced_threshold,
616
+ sinegen_type='1' if self.sampling_rate == 22050 else '2',
617
+ causal=True)
618
+ self.upsample_rates = upsample_rates
619
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
620
+
621
+ self.conv_pre = weight_norm(
622
+ CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
623
+ )
624
+
625
+ # Up
626
+ self.ups = nn.ModuleList()
627
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
628
+ self.ups.append(
629
+ weight_norm(
630
+ CausalConv1dUpsample(
631
+ base_channels // (2**i),
632
+ base_channels // (2**(i + 1)),
633
+ k,
634
+ u,
635
+ )
636
+ )
637
+ )
638
+
639
+ # Down
640
+ self.source_downs = nn.ModuleList()
641
+ self.source_resblocks = nn.ModuleList()
642
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
643
+ downsample_cum_rates = np.cumprod(downsample_rates)
644
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
645
+ if u == 1:
646
+ self.source_downs.append(
647
+ CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
648
+ )
649
+ else:
650
+ self.source_downs.append(
651
+ CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
652
+ )
653
+
654
+ self.source_resblocks.append(
655
+ ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
656
+ )
657
+
658
+ self.resblocks = nn.ModuleList()
659
+ for i in range(len(self.ups)):
660
+ ch = base_channels // (2**(i + 1))
661
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
662
+ self.resblocks.append(ResBlock(ch, k, d, causal=True))
663
+
664
+ self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
665
+ self.ups.apply(init_weights)
666
+ self.conv_post.apply(init_weights)
667
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
668
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
669
+ self.conv_pre_look_right = conv_pre_look_right
670
+ self.f0_predictor = f0_predictor
671
+
672
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
673
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
674
+ if finalize is True:
675
+ x = self.conv_pre(x)
676
+ else:
677
+ x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
678
+ s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
679
+ s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
680
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
681
+
682
+ for i in range(self.num_upsamples):
683
+ x = F.leaky_relu(x, self.lrelu_slope)
684
+ x = self.ups[i](x)
685
+
686
+ if i == self.num_upsamples - 1:
687
+ x = self.reflection_pad(x)
688
+
689
+ # fusion
690
+ si = self.source_downs[i](s_stft)
691
+ si = self.source_resblocks[i](si)
692
+ x = x + si
693
+
694
+ xs = None
695
+ for j in range(self.num_kernels):
696
+ if xs is None:
697
+ xs = self.resblocks[i * self.num_kernels + j](x)
698
+ else:
699
+ xs += self.resblocks[i * self.num_kernels + j](x)
700
+ x = xs / self.num_kernels
701
+
702
+ x = F.leaky_relu(x)
703
+ x = self.conv_post(x)
704
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
705
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, the sin here is redundant
706
+
707
+ x = self._istft(magnitude, phase)
708
+ if finalize is False:
709
+ x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
710
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
711
+ return x
712
+
713
+ @torch.inference_mode()
714
+ def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
715
+ # mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
716
+ self.f0_predictor.to('cpu')
717
+ f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
718
+ # f0->source
719
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
720
+ s, _, _ = self.m_source(s)
721
+ s = s.transpose(1, 2)
722
+ if finalize is True:
723
+ generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
724
+ else:
725
+ generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
726
+ return generated_speech, s
727
+
728
+
729
+ if __name__ == '__main__':
730
+ torch.backends.cudnn.deterministic = True
731
+ torch.backends.cudnn.benchmark = False
732
+ from hyperpyyaml import load_hyperpyyaml
733
+ with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
734
+ configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
735
+ model = configs['hift']
736
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
737
+ model.to(device)
738
+ model.eval()
739
+ max_len, chunk_size, context_size = 300, 30, 8
740
+ mel = torch.rand(1, 80, max_len).to(device)
741
+ pred_gt, _ = model.inference(mel)
742
+ for i in range(0, max_len, chunk_size):
743
+ finalize = True if i + chunk_size + context_size >= max_len else False
744
+ pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
745
+ pred_chunk = pred_chunk[:, i * 480:]
746
+ print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())
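The non-causal HiFTGenerator above streams by reusing its excitation: inference() accepts a cache_source tensor and overwrites the already-vocoded portion of the newly derived source with it, which is what avoids glitches at chunk boundaries. A minimal sketch of that pattern, assuming an already-loaded HiFTGenerator hift, a mel tensor of shape (1, 80, T), and the 480 samples-per-mel-frame factor used in the __main__ check above (the exact factor depends on the configured upsample rates and iSTFT hop length):

import torch

def stream_vocode(hift, mel, chunk_size=30):
    # Sketch only: grow the mel context chunk by chunk and pass the cached
    # excitation (source) back in so chunk boundaries stay consistent.
    cache_source = torch.zeros(1, 1, 0)
    audio = []
    for start in range(0, mel.shape[2], chunk_size):
        feat = mel[:, :, :start + chunk_size]
        speech, source = hift.inference(speech_feat=feat, cache_source=cache_source)
        cache_source = source                  # excitation covering the frames seen so far
        audio.append(speech[:, start * 480:])  # keep only the newly generated samples
    return torch.cat(audio, dim=1)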
cosyvoice/hifigan/hifigan.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Dict, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
6
+ from cosyvoice.utils.losses import tpr_loss, mel_loss
7
+
8
+
9
+ class HiFiGan(nn.Module):
10
+ def __init__(self, generator, discriminator, mel_spec_transform,
11
+ multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
12
+ tpr_loss_weight=1.0, tpr_loss_tau=0.04):
13
+ super(HiFiGan, self).__init__()
14
+ self.generator = generator
15
+ self.discriminator = discriminator
16
+ self.mel_spec_transform = mel_spec_transform
17
+ self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
18
+ self.feat_match_loss_weight = feat_match_loss_weight
19
+ self.tpr_loss_weight = tpr_loss_weight
20
+ self.tpr_loss_tau = tpr_loss_tau
21
+
22
+ def forward(
23
+ self,
24
+ batch: dict,
25
+ device: torch.device,
26
+ ) -> Dict[str, Optional[torch.Tensor]]:
27
+ if batch['turn'] == 'generator':
28
+ return self.forward_generator(batch, device)
29
+ else:
30
+ return self.forward_discriminator(batch, device)
31
+
32
+ def forward_generator(self, batch, device):
33
+ real_speech = batch['speech'].to(device)
34
+ pitch_feat = batch['pitch_feat'].to(device)
35
+ # 1. calculate generator outputs
36
+ generated_speech, generated_f0 = self.generator(batch, device)
37
+ # 2. calculate discriminator outputs
38
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
39
+ # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
40
+ loss_gen, _ = generator_loss(y_d_gs)
41
+ loss_fm = feature_loss(fmap_rs, fmap_gs)
42
+ loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
43
+ if self.tpr_loss_weight != 0:
44
+ loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
45
+ else:
46
+ loss_tpr = torch.zeros(1).to(device)
47
+ loss_f0 = F.l1_loss(generated_f0, pitch_feat)
48
+ loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
49
+ self.multi_mel_spectral_recon_loss_weight * loss_mel + \
50
+ self.tpr_loss_weight * loss_tpr + loss_f0
51
+ return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
52
+
53
+ def forward_discriminator(self, batch, device):
54
+ real_speech = batch['speech'].to(device)
55
+ # 1. calculate generator outputs
56
+ with torch.no_grad():
57
+ generated_speech, generated_f0 = self.generator(batch, device)
58
+ # 2. calculate discriminator outputs
59
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
60
+ # 3. calculate discriminator losses, tpr losses [Optional]
61
+ loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
62
+ if self.tpr_loss_weight != 0:
63
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
64
+ else:
65
+ loss_tpr = torch.zeros(1).to(device)
66
+ loss = loss_disc + self.tpr_loss_weight * loss_tpr
67
+ return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
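Since HiFiGan.forward dispatches on batch['turn'], a GAN training step runs the wrapper twice per batch, once per turn. A rough sketch of that alternation, assuming hifigan is a constructed HiFiGan module and optim_g / optim_d are hypothetical optimizers over generator and discriminator parameters:

# Hypothetical alternating GAN step; batch already holds 'speech', 'speech_feat', 'pitch_feat', ...
batch['turn'] = 'discriminator'
info_d = hifigan(batch, device)    # forward_discriminator: real vs. detached generated speech
optim_d.zero_grad()
info_d['loss'].backward()
optim_d.step()

batch['turn'] = 'generator'
info_g = hifigan(batch, device)    # forward_generator: adversarial + feature-matching + mel + f0 losses
optim_g.zero_grad()
info_g['loss'].backward()
optim_g.step()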
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,739 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import queue
16
+ import random
17
+ import time
18
+ import threading
19
+ from typing import Dict, Optional, Callable, List, Generator
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn
23
+ import torch.nn.functional as F
24
+ from transformers import Qwen2ForCausalLM
25
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
26
+ from cosyvoice.utils.common import IGNORE_ID
27
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
28
+ from cosyvoice.utils.common import th_accuracy
29
+ from cosyvoice.utils.file_utils import logging
30
+ from cosyvoice.utils.mask import make_pad_mask
31
+
32
+
33
+ class TransformerLM(torch.nn.Module):
34
+ def __init__(
35
+ self,
36
+ text_encoder_input_size: int,
37
+ llm_input_size: int,
38
+ llm_output_size: int,
39
+ text_token_size: int,
40
+ speech_token_size: int,
41
+ text_encoder: torch.nn.Module,
42
+ llm: torch.nn.Module,
43
+ sampling: Callable,
44
+ length_normalized_loss: bool = True,
45
+ lsm_weight: float = 0.0,
46
+ spk_embed_dim: int = 192,
47
+ ):
48
+ super().__init__()
49
+ self.llm_input_size = llm_input_size
50
+ self.speech_token_size = speech_token_size
51
+ # 1. build text token inputs related modules
52
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
53
+ self.text_encoder = text_encoder
54
+ self.text_encoder_affine_layer = nn.Linear(
55
+ self.text_encoder.output_size(),
56
+ llm_input_size
57
+ )
58
+
59
+ # 2. build speech token language model related modules
60
+ self.sos = 0
61
+ self.task_id = 1
62
+ self.eos_token = self.speech_token_size
63
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
64
+ self.llm = llm
65
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
66
+ self.criterion_ce = LabelSmoothingLoss(
67
+ size=speech_token_size + 1,
68
+ padding_idx=IGNORE_ID,
69
+ smoothing=lsm_weight,
70
+ normalize_length=length_normalized_loss,
71
+ )
72
+
73
+ # 3. [Optional] build speech token related modules
74
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
75
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
76
+
77
+ # 4. sampling method
78
+ self.sampling = sampling
79
+
80
+ def encode(
81
+ self,
82
+ text: torch.Tensor,
83
+ text_lengths: torch.Tensor,
84
+ ):
85
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
86
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
87
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
88
+ return encoder_out, encoder_out_lens
89
+
90
+ def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
91
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
92
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
93
+ lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
94
+ for i in range(len(text_token))]
95
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
96
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
97
+ return lm_input, lm_input_len
98
+
99
+ def forward(
100
+ self,
101
+ batch: dict,
102
+ device: torch.device,
103
+ ) -> Dict[str, Optional[torch.Tensor]]:
104
+ """
105
+ Args:
106
+ text: (B, L, D)
107
+ text_lengths: (B,)
108
+ audio: (B, T, N) or (B, T)
109
+ audio_lengths: (B,)
110
+ """
111
+ text_token = batch['text_token'].to(device)
112
+ text_token_len = batch['text_token_len'].to(device)
113
+ speech_token = batch['speech_token'].to(device)
114
+ speech_token_len = batch['speech_token_len'].to(device)
115
+ embedding = batch['embedding'].to(device)
116
+
117
+ # 1. prepare llm_target
118
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
119
+ [self.speech_token_size]) for i in range(text_token.size(0))]
120
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
121
+
122
+ # 1. encode text_token
123
+ text_token = self.text_embedding(text_token)
124
+ text_token, text_token_len = self.encode(text_token, text_token_len)
125
+
126
+ # 2. embedding projection
127
+ embedding = F.normalize(embedding, dim=1)
128
+ embedding = self.spk_embed_affine_layer(embedding)
129
+ embedding = embedding.unsqueeze(1)
130
+
131
+ # 3. sos and task_id
132
+ sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
133
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
134
+
135
+ # 4. encode speech_token
136
+ speech_token = self.speech_embedding(speech_token)
137
+
138
+ # 5. unpad and pad
139
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
140
+ task_id_emb, speech_token, speech_token_len)
141
+
142
+ # 6. run lm forward
143
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
144
+ logits = self.llm_decoder(lm_output)
145
+ loss = self.criterion_ce(logits, lm_target)
146
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
147
+ return {'loss': loss, 'acc': acc}
148
+
149
+ def sampling_ids(
150
+ self,
151
+ weighted_scores: torch.Tensor,
152
+ decoded_tokens: List,
153
+ sampling: int,
154
+ ignore_eos: bool = True,
155
+ ):
156
+ num_trials, max_trials = 0, 100
157
+ while True:
158
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
159
+ if (not ignore_eos) or (top_ids < self.speech_token_size):
160
+ break
161
+ num_trials += 1
162
+ if num_trials > max_trials:
163
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
164
+ return top_ids
165
+
166
+ @torch.inference_mode()
167
+ def inference(
168
+ self,
169
+ text: torch.Tensor,
170
+ text_len: torch.Tensor,
171
+ prompt_text: torch.Tensor,
172
+ prompt_text_len: torch.Tensor,
173
+ prompt_speech_token: torch.Tensor,
174
+ prompt_speech_token_len: torch.Tensor,
175
+ embedding: torch.Tensor,
176
+ sampling: int = 25,
177
+ max_token_text_ratio: float = 20,
178
+ min_token_text_ratio: float = 2,
179
+ uuid: str = '',
180
+ ) -> Generator[torch.Tensor, None, None]:
181
+ device = text.device
182
+ text = torch.concat([prompt_text, text], dim=1)
183
+ text_len += prompt_text_len
184
+ text = self.text_embedding(text)
185
+
186
+ # 1. encode text
187
+ text, text_len = self.encode(text, text_len)
188
+
189
+ # 2. encode embedding
190
+ if embedding.shape[0] != 0:
191
+ embedding = F.normalize(embedding, dim=1)
192
+ embedding = self.spk_embed_affine_layer(embedding)
193
+ embedding = embedding.unsqueeze(dim=1)
194
+ else:
195
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
196
+
197
+ # 3. concat llm_input
198
+ sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
199
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
200
+ if prompt_speech_token_len != 0:
201
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
202
+ else:
203
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
204
+ lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
205
+
206
+ # 4. cal min/max_length
207
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
208
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
209
+
210
+ # 5. step by step decode
211
+ out_tokens = []
212
+ offset = 0
213
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
214
+ for i in range(max_len):
215
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
216
+ att_cache=att_cache, cnn_cache=cnn_cache,
217
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
218
+ device=lm_input.device)).to(torch.bool))
219
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
220
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
221
+ if top_ids == self.eos_token:
222
+ break
223
+ # in stream mode, yield token one by one
224
+ yield top_ids
225
+ out_tokens.append(top_ids)
226
+ offset += lm_input.size(1)
227
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
228
+
229
+
230
+ class Qwen2Encoder(torch.nn.Module):
231
+ def __init__(self, pretrain_path):
232
+ super().__init__()
233
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
234
+
235
+ def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
236
+ T = xs.size(1)
237
+ masks = ~make_pad_mask(xs_lens, T)
238
+ outs = self.model(
239
+ inputs_embeds=xs,
240
+ attention_mask=masks,
241
+ output_hidden_states=True,
242
+ return_dict=True,
243
+ )
244
+ return outs.hidden_states[-1], masks.unsqueeze(1)
245
+
246
+ def forward_one_step(self, xs, masks, cache=None):
247
+ input_masks = masks[:, -1, :]
248
+ outs = self.model(
249
+ inputs_embeds=xs,
250
+ attention_mask=input_masks,
251
+ output_hidden_states=True,
252
+ return_dict=True,
253
+ use_cache=True,
254
+ past_key_values=cache,
255
+ )
256
+ xs = outs.hidden_states[-1]
257
+ new_cache = outs.past_key_values
258
+ return xs, new_cache
259
+
260
+
261
+ class Qwen2LM(TransformerLM):
262
+ def __init__(
263
+ self,
264
+ llm_input_size: int,
265
+ llm_output_size: int,
266
+ speech_token_size: int,
267
+ llm: torch.nn.Module,
268
+ sampling: Callable,
269
+ length_normalized_loss: bool = True,
270
+ lsm_weight: float = 0.0,
271
+ mix_ratio: List[int] = [5, 15],
272
+ ):
273
+ torch.nn.Module.__init__(self)
274
+ self.llm_input_size = llm_input_size
275
+ self.llm_output_size = llm_output_size
276
+ self.speech_token_size = speech_token_size
277
+ # 2. build speech token language model related modules
278
+ self.sos = 0
279
+ self.task_id = 1
280
+ self.eos_token = speech_token_size
281
+ self.fill_token = speech_token_size + 2
282
+
283
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
284
+ self.llm = llm
285
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
286
+ self.criterion_ce = LabelSmoothingLoss(
287
+ size=speech_token_size + 3,
288
+ padding_idx=IGNORE_ID,
289
+ smoothing=lsm_weight,
290
+ normalize_length=length_normalized_loss,
291
+ )
292
+
293
+ # 3. [Optional] build speech token related modules
294
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
295
+
296
+ # 4. sampling method
297
+ self.sampling = sampling
298
+ self.mix_ratio = mix_ratio
299
+
300
+ # 5. vllm related
301
+ self.stop_token_ids = [speech_token_size + i for i in range(3)]
302
+ self.vllm_output_queue = {}
303
+
304
+ def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
305
+ lm_target, lm_input = [], []
306
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
307
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
308
+ text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
309
+ speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
310
+ for i in range(len(text_token)):
311
+ # bistream sequence
312
+ if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
313
+ this_lm_target, this_lm_input = [], []
314
+ this_lm_target.append(IGNORE_ID)
315
+ this_lm_input.append(sos_emb.squeeze(dim=0))
316
+ for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
317
+ this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
318
+ this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
319
+ if len(this_text_token) == self.mix_ratio[0]:
320
+ assert len(this_speech_token) == self.mix_ratio[1]
321
+ this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
322
+ this_lm_target += this_speech_token
323
+ this_lm_target.append(self.fill_token)
324
+ this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
325
+ this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
326
+ else:
327
+ this_lm_target += [-1] * len(this_text_token)
328
+ this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
329
+ this_lm_target.append(self.eos_token)
330
+ this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
331
+ this_lm_input.append(task_id_emb.squeeze(dim=0))
332
+ this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
333
+ this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
334
+ # unistream sequence
335
+ else:
336
+ this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
337
+ this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
338
+ lm_target.append(this_lm_target)
339
+ lm_input.append(this_lm_input)
340
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
341
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
342
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
343
+ return lm_target, lm_input, lm_input_len
344
+
345
+ def forward(
346
+ self,
347
+ batch: dict,
348
+ device: torch.device,
349
+ ) -> Dict[str, Optional[torch.Tensor]]:
350
+ """
351
+ Args:
352
+ text: (B, L, D)
353
+ text_lengths: (B,)
354
+ audio: (B, T, N) or (B, T)
355
+ audio_lengths: (B,)
356
+ """
357
+ text_token = batch['text_token'].to(device)
358
+ text_token_len = batch['text_token_len'].to(device)
359
+ speech_token = batch['speech_token'].to(device)
360
+ speech_token_len = batch['speech_token_len'].to(device)
361
+
362
+ # 1. encode text_token
363
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
364
+
365
+ # 3. sos and task_id
366
+ sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
367
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
368
+
369
+ # 2. encode speech_token
370
+ speech_token_emb = self.speech_embedding(speech_token)
371
+
372
+ # 3. prepare llm_input/target
373
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
374
+ speech_token, speech_token_emb, speech_token_len)
375
+ lm_target = lm_target.to(device)
376
+
377
+ # 4. run lm forward
378
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
379
+ logits = self.llm_decoder(lm_output)
380
+ loss = self.criterion_ce(logits, lm_target.to(device))
381
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
382
+ return {'loss': loss, 'acc': acc}
383
+
384
+ def forward_dpo(
385
+ self,
386
+ batch: dict,
387
+ device: torch.device,
388
+ ) -> Dict[str, Optional[torch.Tensor]]:
389
+ text_token = batch['text_token'].to(device)
390
+ text_token_len = batch['text_token_len'].to(device)
391
+ speech_token = batch['speech_token'].to(device)
392
+ speech_token_len = batch['speech_token_len'].to(device)
393
+ reject_speech_token = batch['reject_speech_token'].to(device)
394
+ reject_speech_token_len = batch['reject_speech_token_len'].to(device)
395
+
396
+ # 1. encode text_token
397
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
398
+
399
+ # 3. sos and task_id
400
+ sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
401
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
402
+
403
+ # 2. encode speech_token
404
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
405
+ reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
406
+ speech_token_combined = speech_token + reject_speech_token
407
+ speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
408
+ speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
409
+ speech_token_combined_emb = self.speech_embedding(speech_token_combined)
410
+
411
+ # 3. prepare llm_input/target
412
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
413
+ task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
414
+ lm_target = lm_target.to(device)
415
+
416
+ # 4. run lm forward
417
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
418
+ logits = self.llm_decoder(lm_output)
419
+ chosen_logits = logits[:text_token.shape[0]]
420
+ rejected_logits = logits[text_token.shape[0]:]
421
+ chosen_lm_target = lm_target[:text_token.shape[0]]
422
+ rejected_lm_target = lm_target[text_token.shape[0]:]
423
+ loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
424
+ acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
425
+
426
+ # 5. calculate dpo logits
427
+ chosen_lm_mask = chosen_lm_target == IGNORE_ID
428
+ rejected_lm_mask = rejected_lm_target == IGNORE_ID
429
+ chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
430
+ rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
431
+ chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
432
+ rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
433
+ return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
434
+
435
+ @torch.inference_mode()
436
+ def inference(
437
+ self,
438
+ text: torch.Tensor,
439
+ text_len: torch.Tensor,
440
+ prompt_text: torch.Tensor,
441
+ prompt_text_len: torch.Tensor,
442
+ prompt_speech_token: torch.Tensor,
443
+ prompt_speech_token_len: torch.Tensor,
444
+ embedding: torch.Tensor,
445
+ sampling: int = 25,
446
+ max_token_text_ratio: float = 20,
447
+ min_token_text_ratio: float = 2,
448
+ uuid: str = '',
449
+ ) -> Generator[torch.Tensor, None, None]:
450
+ device = text.device
451
+ text = torch.concat([prompt_text, text], dim=1)
452
+ text_len += prompt_text_len
453
+ text = self.llm.model.model.embed_tokens(text)
454
+
455
+ # 3. concat llm_input
456
+ sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
457
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
458
+ if prompt_speech_token_len != 0:
459
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
460
+ else:
461
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
462
+ lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
463
+
464
+ # 4. cal min/max_length
465
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
466
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
467
+
468
+ # 5. step by step decode
469
+ for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
470
+ yield token
471
+
472
+ @torch.inference_mode()
473
+ def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
474
+ if hasattr(self, 'vllm'):
475
+ from vllm import SamplingParams, RequestOutput
476
+ sampling_params = SamplingParams(top_k=sampling,
477
+ stop_token_ids=self.stop_token_ids,
478
+ min_tokens=min_len,
479
+ max_tokens=max_len)
480
+ with self.lock:
481
+ self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
482
+ self.vllm_output_queue[uuid] = queue.Queue()
483
+ out_tokens = []
484
+ while True:
485
+ with self.lock:
486
+ if self.vllm_output_queue[uuid].empty() is True:
487
+ request_outputs: List[RequestOutput] = self.vllm.step()
488
+ for request_output in request_outputs:
489
+ top_ids = list(request_output.outputs[0].token_ids)[-1]
490
+ self.vllm_output_queue[request_output.request_id].put(top_ids)
491
+ if self.vllm_output_queue[uuid].empty() is False:
492
+ top_ids = self.vllm_output_queue[uuid].get()
493
+ if top_ids in self.stop_token_ids:
494
+ break
495
+ # in stream mode, yield token one by one
496
+ yield top_ids
497
+ out_tokens.append(top_ids)
498
+ if len(out_tokens) == max_len:
499
+ break
500
+ time.sleep(0.001)
501
+ with self.lock:
502
+ self.vllm_output_queue.pop(uuid)
503
+ else:
504
+ out_tokens = []
505
+ cache = None
506
+ for i in range(max_len):
507
+ y_pred, cache = self.llm.forward_one_step(lm_input,
508
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
509
+ cache=cache)
510
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
511
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
512
+ if top_ids in self.stop_token_ids:
513
+ break
514
+ # in stream mode, yield token one by one
515
+ yield top_ids
516
+ out_tokens.append(top_ids)
517
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
518
+
519
+ @torch.inference_mode()
520
+ def inference_bistream(
521
+ self,
522
+ text: Generator,
523
+ prompt_text: torch.Tensor,
524
+ prompt_text_len: torch.Tensor,
525
+ prompt_speech_token: torch.Tensor,
526
+ prompt_speech_token_len: torch.Tensor,
527
+ embedding: torch.Tensor,
528
+ sampling: int = 25,
529
+ max_token_text_ratio: float = 20,
530
+ min_token_text_ratio: float = 2,
531
+ ) -> Generator[torch.Tensor, None, None]:
532
+
533
+ device = prompt_text.device
534
+ # 1. prepare input
535
+ sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
536
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
537
+ if prompt_speech_token_len != 0:
538
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
539
+ else:
540
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
541
+ lm_input = torch.concat([sos_emb], dim=1)
542
+
543
+ # 2. iterate text
544
+ out_tokens = []
545
+ cache = None
546
+ # NOTE initialize text_cache with prompt_text, since prompt_speech_token/prompt_text < 15/5 is basically impossible
547
+ text_cache = self.llm.model.model.embed_tokens(prompt_text)
548
+ next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
549
+ for this_text in text:
550
+ text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
551
+ # prompt_speech_token_emb not empty, try append to lm_input
552
+ while prompt_speech_token_emb.size(1) != 0:
553
+ if text_cache.size(1) >= self.mix_ratio[0]:
554
+ lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
555
+ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
556
+ lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
557
+ text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
558
+ else:
559
+ logging.info('not enough text token to decode, wait for more')
560
+ break
561
+ # no prompt_speech_token_emb remain, can decode some speech token
562
+ if prompt_speech_token_emb.size(1) == 0:
563
+ if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
564
+ logging.info('get fill token, need to append more text token')
565
+ if text_cache.size(1) >= self.mix_ratio[0]:
566
+ lm_input_text = text_cache[:, :self.mix_ratio[0]]
567
+ logging.info('append {} text token'.format(lm_input_text.size(1)))
568
+ if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
569
+ lm_input = lm_input_text
570
+ else:
571
+ lm_input = torch.concat([lm_input, lm_input_text], dim=1)
572
+ text_cache = text_cache[:, self.mix_ratio[0]:]
573
+ else:
574
+ logging.info('not enough text token to decode, wait for more')
575
+ continue
576
+ while True:
577
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
578
+ y_pred, cache = self.llm.forward_one_step(lm_input,
579
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
580
+ cache=cache)
581
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
582
+ if next_fill_index != -1 and len(out_tokens) == next_fill_index:
583
+ top_ids = self.fill_token
584
+ next_fill_index += (self.mix_ratio[1] + 1)
585
+ else:
586
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
587
+ if top_ids == self.fill_token:
588
+ next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
589
+ logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
590
+ out_tokens.append(top_ids)
591
+ if top_ids >= self.speech_token_size:
592
+ if top_ids == self.fill_token:
593
+ break
594
+ else:
595
+ raise ValueError('should not get token {}'.format(top_ids))
596
+ yield top_ids
597
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
598
+
599
+ # 3. final decode
600
+ lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
601
+ logging.info('no more text token, decode until met eos')
602
+ while True:
603
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
604
+ y_pred, cache = self.llm.forward_one_step(lm_input,
605
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
606
+ cache=cache)
607
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
608
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
609
+ out_tokens.append(top_ids)
610
+ if top_ids >= self.speech_token_size:
611
+ if top_ids == self.eos_token:
612
+ break
613
+ else:
614
+ raise ValueError('should not get token {}'.format(top_ids))
615
+ # in stream mode, yield token one by one
616
+ yield top_ids
617
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
618
+
619
+
620
+ class CosyVoice3LM(Qwen2LM):
621
+ def __init__(
622
+ self,
623
+ llm_input_size: int,
624
+ llm_output_size: int,
625
+ speech_token_size: int,
626
+ llm: torch.nn.Module,
627
+ sampling: Callable,
628
+ length_normalized_loss: bool = True,
629
+ lsm_weight: float = 0.0,
630
+ mix_ratio: List[int] = [5, 15],
631
+ ):
632
+ torch.nn.Module.__init__(self)
633
+ self.llm_input_size = llm_input_size
634
+ self.llm_output_size = llm_output_size
635
+ self.speech_token_size = speech_token_size
636
+ # 2. build speech token language model related modules
637
+ self.sos = speech_token_size + 0
638
+ self.eos_token = speech_token_size + 1
639
+ self.task_id = speech_token_size + 2
640
+ self.fill_token = speech_token_size + 3
641
+
642
+ self.llm = llm
643
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
644
+ self.criterion_ce = LabelSmoothingLoss(
645
+ size=speech_token_size + 200,
646
+ padding_idx=IGNORE_ID,
647
+ smoothing=lsm_weight,
648
+ normalize_length=length_normalized_loss,
649
+ )
650
+
651
+ # 3. [Optional] build speech token related modules
652
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
653
+
654
+ # 4. sampling method
655
+ self.sampling = sampling
656
+ self.mix_ratio = mix_ratio
657
+
658
+ # 5. vllm related
659
+ self.stop_token_ids = [speech_token_size + i for i in range(200)]
660
+ self.vllm_output_queue = {}
661
+
662
+ def forward(
663
+ self,
664
+ batch: dict,
665
+ device: torch.device,
666
+ ) -> Dict[str, Optional[torch.Tensor]]:
667
+ """
668
+ Args:
669
+ text: (B, L, D)
670
+ text_lengths: (B,)
671
+ audio: (B, T, N) or (B, T)
672
+ audio_lengths: (B,)
673
+ """
674
+ text_token = batch['text_token'].to(device)
675
+ text_token_len = batch['text_token_len'].to(device)
676
+ speech_token = batch['speech_token'].to(device)
677
+ speech_token_len = batch['speech_token_len'].to(device)
678
+ # NOTE should append instruct_token to sequence, not implemented yet
679
+ instruct_token = batch['instruct_token'].to(device)
680
+ instruct_token_len = batch['instruct_token_len'].to(device)
681
+
682
+ # 1. encode text_token
683
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
684
+
685
+ # 2. sos and task_id
686
+ sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
687
+ task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
688
+
689
+ # 3. encode speech_token
690
+ speech_token_emb = self.speech_embedding(speech_token)
691
+
692
+ # 4. prepare llm_input/target
693
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
694
+ speech_token, speech_token_emb, speech_token_len)
695
+ lm_target = lm_target.to(device)
696
+
697
+ # 5. run lm forward
698
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
699
+ logits = self.llm_decoder(lm_output)
700
+ loss = self.criterion_ce(logits, lm_target.to(device))
701
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)  # class count must match llm_decoder output dim (speech_token_size + 200)
702
+ return {'loss': loss, 'acc': acc}
703
+
704
+ @torch.inference_mode()
705
+ def inference(
706
+ self,
707
+ text: torch.Tensor,
708
+ text_len: torch.Tensor,
709
+ prompt_text: torch.Tensor,
710
+ prompt_text_len: torch.Tensor,
711
+ prompt_speech_token: torch.Tensor,
712
+ prompt_speech_token_len: torch.Tensor,
713
+ embedding: torch.Tensor,
714
+ sampling: int = 25,
715
+ max_token_text_ratio: float = 20,
716
+ min_token_text_ratio: float = 2,
717
+ uuid: str = '',
718
+ ) -> Generator[torch.Tensor, None, None]:
719
+ device = text.device
720
+ text = torch.concat([prompt_text, text], dim=1)
721
+ text_len += prompt_text_len
722
+ text = self.llm.model.model.embed_tokens(text)
723
+
724
+ # 3. concat llm_input
725
+ sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
726
+ task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
727
+ if prompt_speech_token_len != 0:
728
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
729
+ else:
730
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
731
+ lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
732
+
733
+ # 4. cal min/max_length
734
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
735
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
736
+
737
+ # 5. step by step decode
738
+ for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
739
+ yield token
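
For illustration, a minimal sketch of driving `CosyVoice3LM.inference` as a streaming generator. It assumes an already-constructed, checkpoint-loaded `lm` and token tensors produced by the frontend; the helper name is hypothetical, and `embedding` is passed only to satisfy the signature since the inference body above does not use it.

import torch

def stream_speech_tokens(lm, text_ids, prompt_text_ids, prompt_speech_ids):
    # text_ids / prompt_text_ids: 1-D LongTensors from the text tokenizer
    # prompt_speech_ids: 1-D LongTensor of prompt speech tokens
    device = next(lm.parameters()).device
    text = text_ids.unsqueeze(0).to(device)
    prompt_text = prompt_text_ids.unsqueeze(0).to(device)
    prompt_speech = prompt_speech_ids.unsqueeze(0).to(device)
    for token in lm.inference(
            text=text,
            text_len=torch.tensor([text.shape[1]], device=device),
            prompt_text=prompt_text,
            prompt_text_len=torch.tensor([prompt_text.shape[1]], device=device),
            prompt_speech_token=prompt_speech,
            prompt_speech_token_len=torch.tensor([prompt_speech.shape[1]], device=device),
            embedding=torch.zeros(1, 0, device=device),  # unused by the inference body shown above
            sampling=25):
        yield token  # speech tokens arrive one by one, suitable for streaming synthesis
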
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
cosyvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,327 @@
1
+ import base64
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Optional
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from whisper.tokenizer import Tokenizer
8
+
9
+ import tiktoken
10
+
11
+ LANGUAGES = {
12
+ "en": "english",
13
+ "zh": "chinese",
14
+ "de": "german",
15
+ "es": "spanish",
16
+ "ru": "russian",
17
+ "ko": "korean",
18
+ "fr": "french",
19
+ "ja": "japanese",
20
+ "pt": "portuguese",
21
+ "tr": "turkish",
22
+ "pl": "polish",
23
+ "ca": "catalan",
24
+ "nl": "dutch",
25
+ "ar": "arabic",
26
+ "sv": "swedish",
27
+ "it": "italian",
28
+ "id": "indonesian",
29
+ "hi": "hindi",
30
+ "fi": "finnish",
31
+ "vi": "vietnamese",
32
+ "he": "hebrew",
33
+ "uk": "ukrainian",
34
+ "el": "greek",
35
+ "ms": "malay",
36
+ "cs": "czech",
37
+ "ro": "romanian",
38
+ "da": "danish",
39
+ "hu": "hungarian",
40
+ "ta": "tamil",
41
+ "no": "norwegian",
42
+ "th": "thai",
43
+ "ur": "urdu",
44
+ "hr": "croatian",
45
+ "bg": "bulgarian",
46
+ "lt": "lithuanian",
47
+ "la": "latin",
48
+ "mi": "maori",
49
+ "ml": "malayalam",
50
+ "cy": "welsh",
51
+ "sk": "slovak",
52
+ "te": "telugu",
53
+ "fa": "persian",
54
+ "lv": "latvian",
55
+ "bn": "bengali",
56
+ "sr": "serbian",
57
+ "az": "azerbaijani",
58
+ "sl": "slovenian",
59
+ "kn": "kannada",
60
+ "et": "estonian",
61
+ "mk": "macedonian",
62
+ "br": "breton",
63
+ "eu": "basque",
64
+ "is": "icelandic",
65
+ "hy": "armenian",
66
+ "ne": "nepali",
67
+ "mn": "mongolian",
68
+ "bs": "bosnian",
69
+ "kk": "kazakh",
70
+ "sq": "albanian",
71
+ "sw": "swahili",
72
+ "gl": "galician",
73
+ "mr": "marathi",
74
+ "pa": "punjabi",
75
+ "si": "sinhala",
76
+ "km": "khmer",
77
+ "sn": "shona",
78
+ "yo": "yoruba",
79
+ "so": "somali",
80
+ "af": "afrikaans",
81
+ "oc": "occitan",
82
+ "ka": "georgian",
83
+ "be": "belarusian",
84
+ "tg": "tajik",
85
+ "sd": "sindhi",
86
+ "gu": "gujarati",
87
+ "am": "amharic",
88
+ "yi": "yiddish",
89
+ "lo": "lao",
90
+ "uz": "uzbek",
91
+ "fo": "faroese",
92
+ "ht": "haitian creole",
93
+ "ps": "pashto",
94
+ "tk": "turkmen",
95
+ "nn": "nynorsk",
96
+ "mt": "maltese",
97
+ "sa": "sanskrit",
98
+ "lb": "luxembourgish",
99
+ "my": "myanmar",
100
+ "bo": "tibetan",
101
+ "tl": "tagalog",
102
+ "mg": "malagasy",
103
+ "as": "assamese",
104
+ "tt": "tatar",
105
+ "haw": "hawaiian",
106
+ "ln": "lingala",
107
+ "ha": "hausa",
108
+ "ba": "bashkir",
109
+ "jw": "javanese",
110
+ "su": "sundanese",
111
+ "yue": "cantonese",
112
+ "minnan": "minnan",
113
+ "wuyu": "wuyu",
114
+ "dialect": "dialect",
115
+ "zh/en": "zh/en",
116
+ "en/zh": "en/zh",
117
+ }
118
+
119
+ # language code lookup by name, with a few language aliases
120
+ TO_LANGUAGE_CODE = {
121
+ **{language: code for code, language in LANGUAGES.items()},
122
+ "burmese": "my",
123
+ "valencian": "ca",
124
+ "flemish": "nl",
125
+ "haitian": "ht",
126
+ "letzeburgesch": "lb",
127
+ "pushto": "ps",
128
+ "panjabi": "pa",
129
+ "moldavian": "ro",
130
+ "moldovan": "ro",
131
+ "sinhalese": "si",
132
+ "castilian": "es",
133
+ "mandarin": "zh",
134
+ }
135
+
136
+ AUDIO_EVENT = {
137
+ "ASR": "ASR",
138
+ "AED": "AED",
139
+ "SER": "SER",
140
+ "Speech": "Speech",
141
+ "/Speech": "/Speech",
142
+ "BGM": "BGM",
143
+ "/BGM": "/BGM",
144
+ "Laughter": "Laughter",
145
+ "/Laughter": "/Laughter",
146
+ "Applause": "Applause",
147
+ "/Applause": "/Applause",
148
+ }
149
+
150
+ EMOTION = {
151
+ "HAPPY": "HAPPY",
152
+ "SAD": "SAD",
153
+ "ANGRY": "ANGRY",
154
+ "NEUTRAL": "NEUTRAL",
155
+ }
156
+
157
+ TTS_Vocal_Token = {
158
+ "TTS/B": "TTS/B",
159
+ "TTS/O": "TTS/O",
160
+ "TTS/Q": "TTS/Q",
161
+ "TTS/A": "TTS/A",
162
+ "TTS/CO": "TTS/CO",
163
+ "TTS/CL": "TTS/CL",
164
+ "TTS/H": "TTS/H",
165
+ **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
166
+ }
167
+
168
+
169
+ @lru_cache(maxsize=None)
170
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
171
+ vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
172
+ ranks = {
173
+ base64.b64decode(token): int(rank)
174
+ for token, rank in (line.split() for line in open(vocab_path) if line)
175
+ }
176
+ n_vocab = len(ranks)
177
+ special_tokens = {}
178
+
179
+ specials = [
180
+ "<|endoftext|>",
181
+ "<|startoftranscript|>",
182
+ *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
183
+ *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
184
+ *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
185
+ "<|translate|>",
186
+ "<|transcribe|>",
187
+ "<|startoflm|>",
188
+ "<|startofprev|>",
189
+ "<|nospeech|>",
190
+ "<|notimestamps|>",
191
+ *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
192
+ *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
193
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
194
+ ]
195
+
196
+ for token in specials:
197
+ special_tokens[token] = n_vocab
198
+ n_vocab += 1
199
+
200
+ return tiktoken.Encoding(
201
+ name=os.path.basename(vocab_path),
202
+ explicit_n_vocab=n_vocab,
203
+ pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
204
+ mergeable_ranks=ranks,
205
+ special_tokens=special_tokens,
206
+ )
207
+
208
+
209
+ @lru_cache(maxsize=None)
210
+ def get_tokenizer(
211
+ multilingual: bool,
212
+ *,
213
+ num_languages: int = 99,
214
+ language: Optional[str] = None,
215
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
216
+ ) -> Tokenizer:
217
+ if language is not None:
218
+ language = language.lower()
219
+ if language not in LANGUAGES:
220
+ if language in TO_LANGUAGE_CODE:
221
+ language = TO_LANGUAGE_CODE[language]
222
+ else:
223
+ raise ValueError(f"Unsupported language: {language}")
224
+
225
+ if multilingual:
226
+ encoding_name = "multilingual_zh_ja_yue_char_del"
227
+ language = language or "en"
228
+ task = task or "transcribe"
229
+ else:
230
+ encoding_name = "gpt2"
231
+ language = None
232
+ task = None
233
+
234
+ encoding = get_encoding(name=encoding_name, num_languages=num_languages)
235
+
236
+ return Tokenizer(
237
+ encoding=encoding, num_languages=num_languages, language=language, task=task
238
+ )
239
+
240
+
241
+ class CosyVoice2Tokenizer():
242
+ def __init__(self, token_path, skip_special_tokens=True):
243
+ super().__init__()
244
+ # NOTE: non-chat model, all these special tokens remain randomly initialized.
245
+ special_tokens = {
246
+ 'eos_token': '<|endoftext|>',
247
+ 'pad_token': '<|endoftext|>',
248
+ 'additional_special_tokens': [
249
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
250
+ '[breath]', '<strong>', '</strong>', '[noise]',
251
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
252
+ '[quick_breath]',
253
+ "<laughter>", "</laughter>",
254
+ "[hissing]", "[sigh]", "[vocalized-noise]",
255
+ "[lipsmack]", "[mn]"
256
+ ]
257
+ }
258
+ self.special_tokens = special_tokens
259
+ self.tokenizer = AutoTokenizer.from_pretrained(token_path)
260
+ self.tokenizer.add_special_tokens(special_tokens)
261
+ self.skip_special_tokens = skip_special_tokens
262
+
263
+ def encode(self, text, **kwargs):
264
+ tokens = self.tokenizer([text], return_tensors="pt")
265
+ tokens = tokens["input_ids"][0].cpu().tolist()
266
+ return tokens
267
+
268
+ def decode(self, tokens):
269
+ tokens = torch.tensor(tokens, dtype=torch.int64)
270
+ text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
271
+ return text
272
+
273
+
274
+ class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
275
+ def __init__(self, token_path, skip_special_tokens=True):
276
+ # NOTE: non-chat model, all these special tokens remain randomly initialized.
277
+ special_tokens = {
278
+ 'eos_token': '<|endoftext|>',
279
+ 'pad_token': '<|endoftext|>',
280
+ 'additional_special_tokens': [
281
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
282
+ '[breath]', '<strong>', '</strong>', '[noise]',
283
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
284
+ '[quick_breath]',
285
+ "<laughter>", "</laughter>",
286
+ "[hissing]", "[sigh]", "[vocalized-noise]",
287
+ "[lipsmack]", "[mn]", "<|endofsystem|>",
288
+ "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
289
+ "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
290
+ "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
291
+ "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
292
+ "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
293
+ "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
294
+ "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
295
+ "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
296
+ "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
297
+ "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
298
+ "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
299
+ "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
300
+ "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
301
+ "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
302
+ "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
303
+ "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
304
+ "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
305
+ "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
306
+ "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
307
+ "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
308
+ ]
309
+ }
310
+ self.special_tokens = special_tokens
311
+ self.tokenizer = AutoTokenizer.from_pretrained(token_path)
312
+ self.tokenizer.add_special_tokens(special_tokens)
313
+ self.skip_special_tokens = skip_special_tokens
314
+
315
+
316
+ @lru_cache(maxsize=None)
317
+ def get_qwen_tokenizer(
318
+ token_path: str,
319
+ skip_special_tokens: bool,
320
+ version: str = 'cosyvoice2'
321
+ ):
322
+ if version == 'cosyvoice2':
323
+ return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
324
+ elif version == 'cosyvoice3':
325
+ return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
326
+ else:
327
+ raise ValueError
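
As a quick usage sketch of the Qwen-based tokenizer above, an encode/decode round trip looks as follows; the `token_path` is a hypothetical placeholder for the Qwen tokenizer directory bundled with the pretrained model.

from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

tokenizer = get_qwen_tokenizer(
    token_path='pretrained_models/CosyVoice-BlankEN',  # placeholder path, adjust to your checkpoint layout
    skip_special_tokens=True,
    version='cosyvoice3',
)
ids = tokenizer.encode('你好,欢迎使用CosyVoice。[breath]')  # '[breath]' maps to a single special-token id
text = tokenizer.decode(ids)   # special tokens are dropped because skip_special_tokens=True
print(len(ids), text)
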
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
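
A tiny numeric check (a sketch under the defaults above) that `Snake` implements x + (1/α)·sin²(αx): with the default linear-scale α = 1 it should coincide with x + sin²(x) up to the 1e-9 stabiliser.

import torch
from cosyvoice.transformer.activation import Snake

snake = Snake(in_features=4)        # alpha initialised to ones, expects (B, C, T) input
x = torch.randn(2, 4, 8)
expected = x + torch.sin(x) ** 2    # alpha == 1 for every channel
assert torch.allclose(snake(x), expected, atol=1e-6)
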
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will always be `True`
175
+ # and we will always do splitting and
176
+ # concatenation (this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
+ x_padded = torch.cat([zero_pad, x], dim=-1)
240
+
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
+ x = x_padded[:, :, 1:].view_as(x)[
245
+ :, :, :, : x.size(-1) // 2 + 1
246
+ ] # only keep the positions from 0 to time2
247
+ return x
248
+
249
+ def forward(
250
+ self,
251
+ query: torch.Tensor,
252
+ key: torch.Tensor,
253
+ value: torch.Tensor,
254
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
+ pos_emb: torch.Tensor = torch.empty(0),
256
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
257
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
+ Args:
260
+ query (torch.Tensor): Query tensor (#batch, time1, size).
261
+ key (torch.Tensor): Key tensor (#batch, time2, size).
262
+ value (torch.Tensor): Value tensor (#batch, time2, size).
263
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
265
+ pos_emb (torch.Tensor): Positional embedding tensor
266
+ (#batch, time2, size).
267
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
+ where `cache_t == chunk_size * num_decoding_left_chunks`
269
+ and `head * d_k == size`
270
+ Returns:
271
+ torch.Tensor: Output tensor (#batch, time1, d_model).
272
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
+ where `cache_t == chunk_size * num_decoding_left_chunks`
274
+ and `head * d_k == size`
275
+ """
276
+ q, k, v = self.forward_qkv(query, key, value)
277
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
+
279
+ # NOTE(xcsong):
280
+ # when export onnx model, for 1st chunk, we feed
281
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
+ # In all modes, `if cache.size(0) > 0` will always be `True`
284
+ # and we will always do splitting and
285
+ # concatenation (this will simplify onnx export). Note that
286
+ # it's OK to concat & split zero-shaped tensors(see code below).
287
+ # when export jit model, for 1st chunk, we always feed
288
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
+ # >>> a = torch.ones((1, 2, 0, 4))
290
+ # >>> b = torch.ones((1, 2, 3, 4))
291
+ # >>> c = torch.cat((a, b), dim=2)
292
+ # >>> torch.equal(b, c) # True
293
+ # >>> d = torch.split(a, 2, dim=-1)
294
+ # >>> torch.equal(d[0], d[1]) # True
295
+ if cache.size(0) > 0:
296
+ key_cache, value_cache = torch.split(cache,
297
+ cache.size(-1) // 2,
298
+ dim=-1)
299
+ k = torch.cat([key_cache, k], dim=2)
300
+ v = torch.cat([value_cache, v], dim=2)
301
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
302
+ # non-trivial to calculate `next_cache_start` here.
303
+ new_cache = torch.cat((k, v), dim=-1)
304
+
305
+ n_batch_pos = pos_emb.size(0)
306
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
307
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
308
+
309
+ # (batch, head, time1, d_k)
310
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
311
+ # (batch, head, time1, d_k)
312
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
313
+
314
+ # compute attention score
315
+ # first compute matrix a and matrix c
316
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
317
+ # (batch, head, time1, time2)
318
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
319
+
320
+ # compute matrix b and matrix d
321
+ # (batch, head, time1, time2)
322
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
323
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
324
+ if matrix_ac.shape != matrix_bd.shape:
325
+ matrix_bd = self.rel_shift(matrix_bd)
326
+
327
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
328
+ self.d_k) # (batch, head, time1, time2)
329
+
330
+ return self.forward_attention(v, scores, mask), new_cache
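
The KV-cache contract described in the NOTE comments above can be exercised with a small standalone loop (a sketch with arbitrary sizes): start from a zero-shaped cache, feed each step's returned cache back in, and the cached time axis grows by one frame per step.

import torch
from cosyvoice.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=2, n_feat=8, dropout_rate=0.0).eval()
cache = torch.zeros(0, 0, 0, 0)              # zero-shaped cache for the first step
xs = torch.randn(1, 5, 8)
with torch.no_grad():
    for t in range(xs.size(1)):
        frame = xs[:, t:t + 1]               # one new frame used as query/key/value
        out, cache = mha(frame, frame, frame,
                         mask=torch.ones(0, 0, 0, dtype=torch.bool),
                         cache=cache)
        assert cache.shape == (1, 2, t + 1, 8)   # (1, head, cache_t + 1, d_k * 2)
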
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,258 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+ import torch.nn.functional as F
23
+
24
+
25
+ class ConvolutionModule(nn.Module):
26
+ """ConvolutionModule in Conformer model."""
27
+
28
+ def __init__(self,
29
+ channels: int,
30
+ kernel_size: int = 15,
31
+ activation: nn.Module = nn.ReLU(),
32
+ norm: str = "batch_norm",
33
+ causal: bool = False,
34
+ bias: bool = True):
35
+ """Construct an ConvolutionModule object.
36
+ Args:
37
+ channels (int): The number of channels of conv layers.
38
+ kernel_size (int): Kernel size of conv layers.
39
+ causal (int): Whether use causal convolution or not
40
+ """
41
+ super().__init__()
42
+
43
+ self.pointwise_conv1 = nn.Conv1d(
44
+ channels,
45
+ 2 * channels,
46
+ kernel_size=1,
47
+ stride=1,
48
+ padding=0,
49
+ bias=bias,
50
+ )
51
+ # self.lorder is used to distinguish if it's a causal convolution,
52
+ # if self.lorder > 0: it's a causal convolution, the input will be
53
+ # padded with self.lorder frames on the left in forward.
54
+ # else: it's a symmetrical convolution
55
+ if causal:
56
+ padding = 0
57
+ self.lorder = kernel_size - 1
58
+ else:
59
+ # kernel_size should be an odd number for non-causal convolution
60
+ assert (kernel_size - 1) % 2 == 0
61
+ padding = (kernel_size - 1) // 2
62
+ self.lorder = 0
63
+ self.depthwise_conv = nn.Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ stride=1,
68
+ padding=padding,
69
+ groups=channels,
70
+ bias=bias,
71
+ )
72
+
73
+ assert norm in ['batch_norm', 'layer_norm']
74
+ if norm == "batch_norm":
75
+ self.use_layer_norm = False
76
+ self.norm = nn.BatchNorm1d(channels)
77
+ else:
78
+ self.use_layer_norm = True
79
+ self.norm = nn.LayerNorm(channels)
80
+
81
+ self.pointwise_conv2 = nn.Conv1d(
82
+ channels,
83
+ channels,
84
+ kernel_size=1,
85
+ stride=1,
86
+ padding=0,
87
+ bias=bias,
88
+ )
89
+ self.activation = activation
90
+
91
+ def forward(
92
+ self,
93
+ x: torch.Tensor,
94
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
95
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
96
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
97
+ """Compute convolution module.
98
+ Args:
99
+ x (torch.Tensor): Input tensor (#batch, time, channels).
100
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
101
+ (0, 0, 0) means fake mask.
102
+ cache (torch.Tensor): left context cache, it is only
103
+ used in causal convolution (#batch, channels, cache_t),
104
+ (0, 0, 0) means fake cache.
105
+ Returns:
106
+ torch.Tensor: Output tensor (#batch, time, channels).
107
+ """
108
+ # exchange the temporal dimension and the feature dimension
109
+ x = x.transpose(1, 2) # (#batch, channels, time)
110
+
111
+ # mask batch padding
112
+ if mask_pad.size(2) > 0: # time > 0
113
+ x.masked_fill_(~mask_pad, 0.0)
114
+
115
+ if self.lorder > 0:
116
+ if cache.size(2) == 0: # cache_t == 0
117
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
118
+ else:
119
+ assert cache.size(0) == x.size(0) # equal batch
120
+ assert cache.size(1) == x.size(1) # equal channel
121
+ x = torch.cat((cache, x), dim=2)
122
+ assert (x.size(2) > self.lorder)
123
+ new_cache = x[:, :, -self.lorder:]
124
+ else:
125
+ # It's better we just return None if no cache is required,
126
+ # However, for JIT export, here we just fake one tensor instead of
127
+ # None.
128
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
129
+
130
+ # GLU mechanism
131
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
132
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
133
+
134
+ # 1D Depthwise Conv
135
+ x = self.depthwise_conv(x)
136
+ if self.use_layer_norm:
137
+ x = x.transpose(1, 2)
138
+ x = self.activation(self.norm(x))
139
+ if self.use_layer_norm:
140
+ x = x.transpose(1, 2)
141
+ x = self.pointwise_conv2(x)
142
+ # mask batch padding
143
+ if mask_pad.size(2) > 0: # time > 0
144
+ x.masked_fill_(~mask_pad, 0.0)
145
+
146
+ return x.transpose(1, 2), new_cache
147
+
148
+
149
+ # NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
150
+ class CausalConv1d(torch.nn.Conv1d):
151
+ def __init__(
152
+ self,
153
+ in_channels: int,
154
+ out_channels: int,
155
+ kernel_size: int,
156
+ stride: int = 1,
157
+ dilation: int = 1,
158
+ groups: int = 1,
159
+ bias: bool = True,
160
+ padding_mode: str = 'zeros',
161
+ causal_type: str = 'left',
162
+ device=None,
163
+ dtype=None
164
+ ) -> None:
165
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
166
+ kernel_size, stride=1,
167
+ padding=0, dilation=dilation,
168
+ groups=groups, bias=bias,
169
+ padding_mode=padding_mode,
170
+ device=device, dtype=dtype)
171
+ assert stride == 1
172
+ self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
173
+ assert causal_type in ['left', 'right']
174
+ self.causal_type = causal_type
175
+
176
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
177
+ input_timestep = x.shape[2]
178
+ if cache.size(2) == 0:
179
+ cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
180
+ assert cache.size(2) == self.causal_padding
181
+ if self.causal_type == 'left':
182
+ x = torch.concat([cache, x], dim=2)
183
+ else:
184
+ x = torch.concat([x, cache], dim=2)
185
+ x = super(CausalConv1d, self).forward(x)
186
+ assert x.shape[2] == input_timestep
187
+ return x
188
+
189
+
190
+ class CausalConv1dDownSample(torch.nn.Conv1d):
191
+ def __init__(
192
+ self,
193
+ in_channels: int,
194
+ out_channels: int,
195
+ kernel_size: int,
196
+ stride: int = 1,
197
+ dilation: int = 1,
198
+ groups: int = 1,
199
+ bias: bool = True,
200
+ padding_mode: str = 'zeros',
201
+ device=None,
202
+ dtype=None
203
+ ) -> None:
204
+ super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
205
+ kernel_size, stride,
206
+ padding=0, dilation=dilation,
207
+ groups=groups, bias=bias,
208
+ padding_mode=padding_mode,
209
+ device=device, dtype=dtype)
210
+ assert stride != 1 and dilation == 1
211
+ assert kernel_size % stride == 0
212
+ self.causal_padding = stride - 1
213
+
214
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
215
+ if cache.size(2) == 0:
216
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
217
+ else:
218
+ assert cache.size(2) == self.causal_padding
219
+ x = torch.concat([cache, x], dim=2)
220
+ x = super(CausalConv1dDownSample, self).forward(x)
221
+ return x
222
+
223
+
224
+ class CausalConv1dUpsample(torch.nn.Conv1d):
225
+ def __init__(
226
+ self,
227
+ in_channels: int,
228
+ out_channels: int,
229
+ kernel_size: int,
230
+ stride: int = 1,
231
+ dilation: int = 1,
232
+ groups: int = 1,
233
+ bias: bool = True,
234
+ padding_mode: str = 'zeros',
235
+ device=None,
236
+ dtype=None
237
+ ) -> None:
238
+ super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
239
+ kernel_size, 1,
240
+ padding=0, dilation=dilation,
241
+ groups=groups, bias=bias,
242
+ padding_mode=padding_mode,
243
+ device=device, dtype=dtype)
244
+ assert dilation == 1
245
+ self.causal_padding = kernel_size - 1
246
+ self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
247
+
248
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
249
+ x = self.upsample(x)
250
+ input_timestep = x.shape[2]
251
+ if cache.size(2) == 0:
252
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
253
+ else:
254
+ assert cache.size(2) == self.causal_padding
255
+ x = torch.concat([cache, x], dim=2)
256
+ x = super(CausalConv1dUpsample, self).forward(x)
257
+ assert input_timestep == x.shape[2]
258
+ return x
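
A short sketch of the left-padding behaviour of `CausalConv1d` defined above: the output keeps the input's time length, and chunked inference matches full-sequence inference when the tail of the previous chunk is passed as the next chunk's cache (channel count, kernel size and lengths below are arbitrary).

import torch
from cosyvoice.transformer.convolution import CausalConv1d

conv = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3)
x = torch.randn(1, 4, 10)                    # (batch, channels, time)
with torch.no_grad():
    y = conv(x)                              # first call pads with a zero cache internally
    assert y.shape == (1, 4, 10)             # time dimension is preserved
    # streaming: carry the tail of chunk 1 as the cache of chunk 2
    y1 = conv(x[:, :, :6])
    y2 = conv(x[:, :, 6:], cache=x[:, :, 6 - conv.causal_padding:6])
    assert torch.allclose(torch.cat([y1, y2], dim=2), y, atol=1e-6)
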
cosyvoice/transformer/decoder.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Decoder definition."""
17
+ from typing import Tuple, List, Optional
18
+
19
+ import torch
20
+ import torch.utils.checkpoint as ckpt
21
+ import logging
22
+
23
+ from cosyvoice.transformer.decoder_layer import DecoderLayer
24
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
25
+ from cosyvoice.utils.class_utils import (
26
+ COSYVOICE_EMB_CLASSES,
27
+ COSYVOICE_ATTENTION_CLASSES,
28
+ COSYVOICE_ACTIVATION_CLASSES,
29
+ )
30
+ from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
31
+
32
+
33
+ class TransformerDecoder(torch.nn.Module):
34
+ """Base class of Transformer decoder module.
35
+ Args:
36
+ vocab_size: output dim
37
+ encoder_output_size: dimension of attention
38
+ attention_heads: the number of heads of multi head attention
39
+ linear_units: the hidden units number of position-wise feedforward
40
+ num_blocks: the number of decoder blocks
41
+ dropout_rate: dropout rate
42
+ self_attention_dropout_rate: dropout rate for attention
43
+ input_layer: input layer type
44
+ use_output_layer: whether to use output layer
45
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
46
+ normalize_before:
47
+ True: use layer_norm before each sub-block of a layer.
48
+ False: use layer_norm after each sub-block of a layer.
49
+ src_attention: if false, encoder-decoder cross attention is not
50
+ applied, such as CIF model
51
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
52
+ gradient_checkpointing: rerunning a forward-pass segment for each
53
+ checkpointed segment during backward.
54
+ tie_word_embedding: Tie or clone module weights depending of whether we are
55
+ using TorchScript or not
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ vocab_size: int,
61
+ encoder_output_size: int,
62
+ attention_heads: int = 4,
63
+ linear_units: int = 2048,
64
+ num_blocks: int = 6,
65
+ dropout_rate: float = 0.1,
66
+ positional_dropout_rate: float = 0.1,
67
+ self_attention_dropout_rate: float = 0.0,
68
+ src_attention_dropout_rate: float = 0.0,
69
+ input_layer: str = "embed",
70
+ use_output_layer: bool = True,
71
+ normalize_before: bool = True,
72
+ src_attention: bool = True,
73
+ key_bias: bool = True,
74
+ activation_type: str = "relu",
75
+ gradient_checkpointing: bool = False,
76
+ tie_word_embedding: bool = False,
77
+ ):
78
+ super().__init__()
79
+ attention_dim = encoder_output_size
80
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
81
+
82
+ self.embed = torch.nn.Sequential(
83
+ torch.nn.Identity() if input_layer == "no_pos" else
84
+ torch.nn.Embedding(vocab_size, attention_dim),
85
+ COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
86
+ positional_dropout_rate),
87
+ )
88
+
89
+ self.normalize_before = normalize_before
90
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
91
+ self.use_output_layer = use_output_layer
92
+ if use_output_layer:
93
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
94
+ else:
95
+ self.output_layer = torch.nn.Identity()
96
+ self.num_blocks = num_blocks
97
+ self.decoders = torch.nn.ModuleList([
98
+ DecoderLayer(
99
+ attention_dim,
100
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
101
+ attention_heads, attention_dim,
102
+ self_attention_dropout_rate, key_bias),
103
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
104
+ attention_heads, attention_dim, src_attention_dropout_rate,
105
+ key_bias) if src_attention else None,
106
+ PositionwiseFeedForward(attention_dim, linear_units,
107
+ dropout_rate, activation),
108
+ dropout_rate,
109
+ normalize_before,
110
+ ) for _ in range(self.num_blocks)
111
+ ])
112
+
113
+ self.gradient_checkpointing = gradient_checkpointing
114
+ self.tie_word_embedding = tie_word_embedding
115
+
116
+ def forward(
117
+ self,
118
+ memory: torch.Tensor,
119
+ memory_mask: torch.Tensor,
120
+ ys_in_pad: torch.Tensor,
121
+ ys_in_lens: torch.Tensor,
122
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
123
+ reverse_weight: float = 0.0,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
125
+ """Forward decoder.
126
+ Args:
127
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
128
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
129
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
130
+ ys_in_lens: input lengths of this batch (batch)
131
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
132
+ with bidirectional decoder
133
+ reverse_weight: not used in transformer decoder, in order to unify
134
+ api with bidirectional decode
135
+ Returns:
136
+ (tuple): tuple containing:
137
+ x: decoded token score before softmax (batch, maxlen_out,
138
+ vocab_size) if use_output_layer is True,
139
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
140
+ olens: (batch, )
141
+ NOTE(xcsong):
142
+ We pass the `__call__` method of the modules instead of `forward` to the
143
+ checkpointing API because `__call__` attaches all the hooks of the module.
144
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
145
+ """
146
+ tgt = ys_in_pad
147
+ maxlen = tgt.size(1)
148
+ # tgt_mask: (B, 1, L)
149
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
150
+ tgt_mask = tgt_mask.to(tgt.device)
151
+ # m: (1, L, L)
152
+ m = subsequent_mask(tgt_mask.size(-1),
153
+ device=tgt_mask.device).unsqueeze(0)
154
+ # tgt_mask: (B, L, L)
155
+ tgt_mask = tgt_mask & m
156
+ x, _ = self.embed(tgt)
157
+ if self.gradient_checkpointing and self.training:
158
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory,
159
+ memory_mask)
160
+ else:
161
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
162
+ if self.normalize_before:
163
+ x = self.after_norm(x)
164
+ if self.use_output_layer:
165
+ x = self.output_layer(x)
166
+ olens = tgt_mask.sum(1)
167
+ return x, torch.tensor(0.0), olens
168
+
169
+ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
170
+ memory: torch.Tensor,
171
+ memory_mask: torch.Tensor) -> torch.Tensor:
172
+ for layer in self.decoders:
173
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
174
+ memory_mask)
175
+ return x
176
+
177
+ @torch.jit.unused
178
+ def forward_layers_checkpointed(self, x: torch.Tensor,
179
+ tgt_mask: torch.Tensor,
180
+ memory: torch.Tensor,
181
+ memory_mask: torch.Tensor) -> torch.Tensor:
182
+ for layer in self.decoders:
183
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
184
+ layer.__call__, x, tgt_mask, memory, memory_mask)
185
+ return x
186
+
187
+ def forward_one_step(
188
+ self,
189
+ memory: torch.Tensor,
190
+ memory_mask: torch.Tensor,
191
+ tgt: torch.Tensor,
192
+ tgt_mask: torch.Tensor,
193
+ cache: Optional[List[torch.Tensor]] = None,
194
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
195
+ """Forward one step.
196
+ This is only used for decoding.
197
+ Args:
198
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
199
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
200
+ tgt: input token ids, int64 (batch, maxlen_out)
201
+ tgt_mask: input token mask, (batch, maxlen_out)
202
+ dtype=torch.uint8 in PyTorch 1.2-
203
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
204
+ cache: cached output list of (batch, max_time_out-1, size)
205
+ Returns:
206
+ y, cache: NN output value and cache per `self.decoders`.
207
+ `y.shape` is (batch, maxlen_out, token)
208
+ """
209
+ x, _ = self.embed(tgt)
210
+ new_cache = []
211
+ for i, decoder in enumerate(self.decoders):
212
+ if cache is None:
213
+ c = None
214
+ else:
215
+ c = cache[i]
216
+ x, tgt_mask, memory, memory_mask = decoder(x,
217
+ tgt_mask,
218
+ memory,
219
+ memory_mask,
220
+ cache=c)
221
+ new_cache.append(x)
222
+ if self.normalize_before:
223
+ y = self.after_norm(x[:, -1])
224
+ else:
225
+ y = x[:, -1]
226
+ if self.use_output_layer:
227
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
228
+ return y, new_cache
229
+
230
+ def tie_or_clone_weights(self, jit_mode: bool = True):
231
+ """Tie or clone module weights (between word_emb and output_layer)
232
+ depending of whether we are using TorchScript or not"""
233
+ if not self.use_output_layer:
234
+ return
235
+ if jit_mode:
236
+ logging.info("clone emb.weight to output.weight")
237
+ self.output_layer.weight = torch.nn.Parameter(
238
+ self.embed[0].weight.clone())
239
+ else:
240
+ logging.info("tie emb.weight with output.weight")
241
+ self.output_layer.weight = self.embed[0].weight
242
+
243
+ if getattr(self.output_layer, "bias", None) is not None:
244
+ self.output_layer.bias.data = torch.nn.functional.pad(
245
+ self.output_layer.bias.data,
246
+ (
247
+ 0,
248
+ self.output_layer.weight.shape[0] -
249
+ self.output_layer.bias.shape[0],
250
+ ),
251
+ "constant",
252
+ 0,
253
+ )
254
+
255
+
256
+ class BiTransformerDecoder(torch.nn.Module):
257
+ """Base class of Transformer decoder module.
258
+ Args:
259
+ vocab_size: output dim
260
+ encoder_output_size: dimension of attention
261
+ attention_heads: the number of heads of multi head attention
262
+ linear_units: the hidden units number of position-wise feedforward
263
+ num_blocks: the number of decoder blocks
264
+ r_num_blocks: the number of right to left decoder blocks
265
+ dropout_rate: dropout rate
266
+ self_attention_dropout_rate: dropout rate for attention
267
+ input_layer: input layer type
268
+ use_output_layer: whether to use output layer
269
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
270
+ normalize_before:
271
+ True: use layer_norm before each sub-block of a layer.
272
+ False: use layer_norm after each sub-block of a layer.
273
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ vocab_size: int,
279
+ encoder_output_size: int,
280
+ attention_heads: int = 4,
281
+ linear_units: int = 2048,
282
+ num_blocks: int = 6,
283
+ r_num_blocks: int = 0,
284
+ dropout_rate: float = 0.1,
285
+ positional_dropout_rate: float = 0.1,
286
+ self_attention_dropout_rate: float = 0.0,
287
+ src_attention_dropout_rate: float = 0.0,
288
+ input_layer: str = "embed",
289
+ use_output_layer: bool = True,
290
+ normalize_before: bool = True,
291
+ key_bias: bool = True,
292
+ gradient_checkpointing: bool = False,
293
+ tie_word_embedding: bool = False,
294
+ ):
295
+
296
+ super().__init__()
297
+ self.tie_word_embedding = tie_word_embedding
298
+ self.left_decoder = TransformerDecoder(
299
+ vocab_size,
300
+ encoder_output_size,
301
+ attention_heads,
302
+ linear_units,
303
+ num_blocks,
304
+ dropout_rate,
305
+ positional_dropout_rate,
306
+ self_attention_dropout_rate,
307
+ src_attention_dropout_rate,
308
+ input_layer,
309
+ use_output_layer,
310
+ normalize_before,
311
+ key_bias=key_bias,
312
+ gradient_checkpointing=gradient_checkpointing,
313
+ tie_word_embedding=tie_word_embedding)
314
+
315
+ self.right_decoder = TransformerDecoder(
316
+ vocab_size,
317
+ encoder_output_size,
318
+ attention_heads,
319
+ linear_units,
320
+ r_num_blocks,
321
+ dropout_rate,
322
+ positional_dropout_rate,
323
+ self_attention_dropout_rate,
324
+ src_attention_dropout_rate,
325
+ input_layer,
326
+ use_output_layer,
327
+ normalize_before,
328
+ key_bias=key_bias,
329
+ gradient_checkpointing=gradient_checkpointing,
330
+ tie_word_embedding=tie_word_embedding)
331
+
332
+ def forward(
333
+ self,
334
+ memory: torch.Tensor,
335
+ memory_mask: torch.Tensor,
336
+ ys_in_pad: torch.Tensor,
337
+ ys_in_lens: torch.Tensor,
338
+ r_ys_in_pad: torch.Tensor,
339
+ reverse_weight: float = 0.0,
340
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
341
+ """Forward decoder.
342
+ Args:
343
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
344
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
345
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
346
+ ys_in_lens: input lengths of this batch (batch)
347
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
348
+ used for right to left decoder
349
+ reverse_weight: used for right to left decoder
350
+ Returns:
351
+ (tuple): tuple containing:
352
+ x: decoded token score before softmax (batch, maxlen_out,
353
+ vocab_size) if use_output_layer is True,
354
+ r_x: x: decoded token score (right to left decoder)
355
+ before softmax (batch, maxlen_out, vocab_size)
356
+ if use_output_layer is True,
357
+ olens: (batch, )
358
+ """
359
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
360
+ ys_in_lens)
361
+ r_x = torch.tensor(0.0)
362
+ if reverse_weight > 0.0:
363
+ r_x, _, olens = self.right_decoder(memory, memory_mask,
364
+ r_ys_in_pad, ys_in_lens)
365
+ return l_x, r_x, olens
366
+
367
+ def forward_one_step(
368
+ self,
369
+ memory: torch.Tensor,
370
+ memory_mask: torch.Tensor,
371
+ tgt: torch.Tensor,
372
+ tgt_mask: torch.Tensor,
373
+ cache: Optional[List[torch.Tensor]] = None,
374
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
375
+ """Forward one step.
376
+ This is only used for decoding.
377
+ Args:
378
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
379
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
380
+ tgt: input token ids, int64 (batch, maxlen_out)
381
+ tgt_mask: input token mask, (batch, maxlen_out)
382
+ dtype=torch.uint8 in PyTorch 1.2-
383
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
384
+ cache: cached output list of (batch, max_time_out-1, size)
385
+ Returns:
386
+ y, cache: NN output value and cache per `self.decoders`.
387
+ `y.shape` is (batch, maxlen_out, token)
388
+ """
389
+ return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
390
+ tgt_mask, cache)
391
+
392
+ def tie_or_clone_weights(self, jit_mode: bool = True):
393
+ """Tie or clone module weights (between word_emb and output_layer)
394
+ depending of whether we are using TorchScript or not"""
395
+ self.left_decoder.tie_or_clone_weights(jit_mode)
396
+ self.right_decoder.tie_or_clone_weights(jit_mode)
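
A minimal left-to-right pass through the TransformerDecoder that both branches above are built from (a sketch: the constructor keywords mirror the call above, the remaining defaults of the wenet-style decoder are assumed, and all shapes are illustrative only):

import torch
from cosyvoice.transformer.decoder import TransformerDecoder

decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256,
                             attention_heads=4, linear_units=1024, num_blocks=2)
memory = torch.randn(2, 20, 256)                      # encoder output (B, T_in, D)
memory_mask = torch.ones(2, 1, 20, dtype=torch.bool)  # (B, 1, T_in)
ys_in_pad = torch.randint(0, 100, (2, 6))             # sos + target tokens (B, T_out)
ys_in_lens = torch.tensor([6, 4])
logits, _, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
print(logits.shape)  # (2, 6, 100): token scores before softmax
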
cosyvoice/transformer/decoder_layer.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Decoder self-attention layer definition."""
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class DecoderLayer(nn.Module):
23
+ """Single decoder layer module.
24
+
25
+ Args:
26
+ size (int): Input dimension.
27
+ self_attn (torch.nn.Module): Self-attention module instance.
28
+ `MultiHeadedAttention` instance can be used as the argument.
29
+ src_attn (torch.nn.Module): Inter-attention module instance.
30
+ `MultiHeadedAttention` instance can be used as the argument.
31
+ If `None` is passed, Inter-attention is not used, such as
32
+ CIF, GPT, and other decoder-only models.
33
+ feed_forward (torch.nn.Module): Feed-forward module instance.
34
+ `PositionwiseFeedForward` instance can be used as the argument.
35
+ dropout_rate (float): Dropout rate.
36
+ normalize_before (bool):
37
+ True: use layer_norm before each sub-block.
38
+ False: use layer_norm after each sub-block.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ size: int,
44
+ self_attn: nn.Module,
45
+ src_attn: Optional[nn.Module],
46
+ feed_forward: nn.Module,
47
+ dropout_rate: float,
48
+ normalize_before: bool = True,
49
+ ):
50
+ """Construct a DecoderLayer object."""
51
+ super().__init__()
52
+ self.size = size
53
+ self.self_attn = self_attn
54
+ self.src_attn = src_attn
55
+ self.feed_forward = feed_forward
56
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
57
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
58
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
59
+ self.dropout = nn.Dropout(dropout_rate)
60
+ self.normalize_before = normalize_before
61
+
62
+ def forward(
63
+ self,
64
+ tgt: torch.Tensor,
65
+ tgt_mask: torch.Tensor,
66
+ memory: torch.Tensor,
67
+ memory_mask: torch.Tensor,
68
+ cache: Optional[torch.Tensor] = None
69
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70
+ """Compute decoded features.
71
+
72
+ Args:
73
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74
+ tgt_mask (torch.Tensor): Mask for input tensor
75
+ (#batch, maxlen_out).
76
+ memory (torch.Tensor): Encoded memory
77
+ (#batch, maxlen_in, size).
78
+ memory_mask (torch.Tensor): Encoded memory mask
79
+ (#batch, maxlen_in).
80
+ cache (torch.Tensor): cached tensors.
81
+ (#batch, maxlen_out - 1, size).
82
+
83
+ Returns:
84
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
85
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88
+
89
+ """
90
+ residual = tgt
91
+ if self.normalize_before:
92
+ tgt = self.norm1(tgt)
93
+
94
+ if cache is None:
95
+ tgt_q = tgt
96
+ tgt_q_mask = tgt_mask
97
+ else:
98
+ # compute only the last frame query keeping dim: max_time_out -> 1
99
+ assert cache.shape == (
100
+ tgt.shape[0],
101
+ tgt.shape[1] - 1,
102
+ self.size,
103
+ ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104
+ tgt_q = tgt[:, -1:, :]
105
+ residual = residual[:, -1:, :]
106
+ tgt_q_mask = tgt_mask[:, -1:, :]
107
+
108
+ x = residual + self.dropout(
109
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110
+ if not self.normalize_before:
111
+ x = self.norm1(x)
112
+
113
+ if self.src_attn is not None:
114
+ residual = x
115
+ if self.normalize_before:
116
+ x = self.norm2(x)
117
+ x = residual + self.dropout(
118
+ self.src_attn(x, memory, memory, memory_mask)[0])
119
+ if not self.normalize_before:
120
+ x = self.norm2(x)
121
+
122
+ residual = x
123
+ if self.normalize_before:
124
+ x = self.norm3(x)
125
+ x = residual + self.dropout(self.feed_forward(x))
126
+ if not self.normalize_before:
127
+ x = self.norm3(x)
128
+
129
+ if cache is not None:
130
+ x = torch.cat([cache, x], dim=1)
131
+
132
+ return x, tgt_mask, memory, memory_mask
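
A smoke test for DecoderLayer (a sketch: the stand-in attention module below is hypothetical and only mimics the (output, extras) tuple return convention expected by the layer; it is not the project's MultiHeadedAttention):

import torch
from torch import nn
from cosyvoice.transformer.decoder_layer import DecoderLayer


class StubAttention(nn.Module):
    """Hypothetical stand-in that returns an (output, weights) tuple."""

    def __init__(self, size: int):
        super().__init__()
        self.proj = nn.Linear(size, size)

    def forward(self, query, key, value, mask, **kwargs):
        return self.proj(query), None


size = 8
layer = DecoderLayer(size,
                     self_attn=StubAttention(size),
                     src_attn=StubAttention(size),
                     feed_forward=nn.Sequential(nn.Linear(size, 16), nn.ReLU(),
                                                nn.Linear(16, size)),
                     dropout_rate=0.1)
tgt = torch.randn(2, 5, size)
tgt_mask = torch.ones(2, 5, 5, dtype=torch.bool)
memory = torch.randn(2, 7, size)
memory_mask = torch.ones(2, 1, 7, dtype=torch.bool)
x, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask)
print(x.shape)  # torch.Size([2, 5, 8])
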
cosyvoice/transformer/embedding.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positional Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class PositionalEncoding(torch.nn.Module):
27
+ """Positional encoding.
28
+
29
+ :param int d_model: embedding dim
30
+ :param float dropout_rate: dropout rate
31
+ :param int max_len: maximum input length
32
+
33
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
+ """
36
+
37
+ def __init__(self,
38
+ d_model: int,
39
+ dropout_rate: float,
40
+ max_len: int = 5000,
41
+ reverse: bool = False):
42
+ """Construct a PositionalEncoding object."""
43
+ super().__init__()
44
+ self.d_model = d_model
45
+ self.xscale = math.sqrt(self.d_model)
46
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
47
+ self.max_len = max_len
48
+
49
+ self.pe = torch.zeros(self.max_len, self.d_model)
50
+ position = torch.arange(0, self.max_len,
51
+ dtype=torch.float32).unsqueeze(1)
52
+ div_term = torch.exp(
53
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
54
+ -(math.log(10000.0) / self.d_model))
55
+ self.pe[:, 0::2] = torch.sin(position * div_term)
56
+ self.pe[:, 1::2] = torch.cos(position * div_term)
57
+ self.pe = self.pe.unsqueeze(0)
58
+
59
+ def forward(self,
60
+ x: torch.Tensor,
61
+ offset: Union[int, torch.Tensor] = 0) \
62
+ -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """Add positional encoding.
64
+
65
+ Args:
66
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
67
+ offset (int, torch.tensor): position offset
68
+
69
+ Returns:
70
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
71
+ torch.Tensor: for compatibility to RelPositionalEncoding
72
+ """
73
+
74
+ self.pe = self.pe.to(x.device)
75
+ pos_emb = self.position_encoding(offset, x.size(1), False)
76
+ x = x * self.xscale + pos_emb
77
+ return self.dropout(x), self.dropout(pos_emb)
78
+
79
+ def position_encoding(self,
80
+ offset: Union[int, torch.Tensor],
81
+ size: int,
82
+ apply_dropout: bool = True) -> torch.Tensor:
83
+ """ For getting encoding in a streaming fashion
84
+
85
+ Attention!!!!!
86
+ we apply dropout only once at the whole utterance level in a
87
+ non-streaming way, but will call this function several times with
88
+ increasing input size in a streaming scenario, so the dropout will
89
+ be applied several times.
90
+
91
+ Args:
92
+ offset (int or torch.tensor): start offset
93
+ size (int): required size of position encoding
94
+
95
+ Returns:
96
+ torch.Tensor: Corresponding encoding
97
+ """
98
+ # How to subscript a Union type:
99
+ # https://github.com/pytorch/pytorch/issues/69434
100
+ if isinstance(offset, int):
101
+ assert offset + size <= self.max_len
102
+ pos_emb = self.pe[:, offset:offset + size]
103
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
+ assert offset + size <= self.max_len
105
+ pos_emb = self.pe[:, offset:offset + size]
106
+ else: # for batched streaming decoding on GPU
107
+ assert torch.max(offset) + size <= self.max_len
108
+ index = offset.unsqueeze(1) + \
109
+ torch.arange(0, size).to(offset.device) # B X T
110
+ flag = index > 0
111
+ # remove negative offset
112
+ index = index * flag
113
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
114
+
115
+ if apply_dropout:
116
+ pos_emb = self.dropout(pos_emb)
117
+ return pos_emb
118
+
119
+
120
+ class RelPositionalEncoding(PositionalEncoding):
121
+ """Relative positional encoding module.
122
+ See : Appendix B in https://arxiv.org/abs/1901.02860
123
+ Args:
124
+ d_model (int): Embedding dimension.
125
+ dropout_rate (float): Dropout rate.
126
+ max_len (int): Maximum input length.
127
+ """
128
+
129
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
130
+ """Initialize class."""
131
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
132
+
133
+ def forward(self,
134
+ x: torch.Tensor,
135
+ offset: Union[int, torch.Tensor] = 0) \
136
+ -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute positional encoding.
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+ Returns:
141
+ torch.Tensor: Encoded tensor (batch, time, `*`).
142
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
143
+ """
144
+ self.pe = self.pe.to(x.device)
145
+ x = x * self.xscale
146
+ pos_emb = self.position_encoding(offset, x.size(1), False)
147
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
+ class WhisperPositionalEncoding(PositionalEncoding):
151
+ """ Sinusoids position encoding used in openai-whisper.encoder
152
+ """
153
+
154
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
155
+ super().__init__(d_model, dropout_rate, max_len)
156
+ self.xscale = 1.0
157
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
158
+ inv_timescales = torch.exp(-log_timescale_increment *
159
+ torch.arange(d_model // 2))
160
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
161
+ inv_timescales[np.newaxis, :]
162
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
163
+ delattr(self, "pe")
164
+ self.register_buffer("pe", pe.unsqueeze(0))
165
+
166
+
167
+ class LearnablePositionalEncoding(PositionalEncoding):
168
+ """ Learnable position encoding used in openai-whisper.decoder
169
+ """
170
+
171
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
+ super().__init__(d_model, dropout_rate, max_len)
173
+ # NOTE(xcsong): overwrite self.pe & self.xscale
174
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
+ self.xscale = 1.0
176
+
177
+
178
+ class NoPositionalEncoding(torch.nn.Module):
179
+ """ No position encoding
180
+ """
181
+
182
+ def __init__(self, d_model: int, dropout_rate: float):
183
+ super().__init__()
184
+ self.d_model = d_model
185
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
186
+
187
+ def forward(self,
188
+ x: torch.Tensor,
189
+ offset: Union[int, torch.Tensor] = 0) \
190
+ -> Tuple[torch.Tensor, torch.Tensor]:
191
+ """ Just return zero vector for interface compatibility
192
+ """
193
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
194
+ return self.dropout(x), pos_emb
195
+
196
+ def position_encoding(self, offset: Union[int, torch.Tensor],
197
+ size: int) -> torch.Tensor:
198
+ return torch.zeros(1, size, self.d_model)
199
+
200
+
201
+ class EspnetRelPositionalEncoding(torch.nn.Module):
202
+ """Relative positional encoding module (new implementation).
203
+
204
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
205
+
206
+ See : Appendix B in https://arxiv.org/abs/1901.02860
207
+
208
+ Args:
209
+ d_model (int): Embedding dimension.
210
+ dropout_rate (float): Dropout rate.
211
+ max_len (int): Maximum input length.
212
+
213
+ """
214
+
215
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
216
+ """Construct an EspnetRelPositionalEncoding object."""
217
+ super(EspnetRelPositionalEncoding, self).__init__()
218
+ self.d_model = d_model
219
+ self.xscale = math.sqrt(self.d_model)
220
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
221
+ self.pe = None
222
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
+
224
+ def extend_pe(self, x: torch.Tensor):
225
+ """Reset the positional encodings."""
226
+ if self.pe is not None:
227
+ # self.pe contains both positive and negative parts
228
+ # the length of self.pe is 2 * input_len - 1
229
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
230
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
231
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
232
+ return
233
+ # Suppose `i` means the position of the query vector and `j` means the
234
+ # position of the key vector. We use positive relative positions when keys
235
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
236
+ pe_positive = torch.zeros(x.size(1), self.d_model)
237
+ pe_negative = torch.zeros(x.size(1), self.d_model)
238
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239
+ div_term = torch.exp(
240
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
241
+ * -(math.log(10000.0) / self.d_model)
242
+ )
243
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
244
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
245
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247
+
248
+ # Reverse the order of positive indices and concat both positive and
249
+ # negative indices. This is used to support the shifting trick
250
+ # as in https://arxiv.org/abs/1901.02860
251
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252
+ pe_negative = pe_negative[1:].unsqueeze(0)
253
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
254
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
255
+
256
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257
+ -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Add positional encoding.
259
+
260
+ Args:
261
+ x (torch.Tensor): Input tensor (batch, time, `*`).
262
+
263
+ Returns:
264
+ torch.Tensor: Encoded tensor (batch, time, `*`).
265
+
266
+ """
267
+ self.extend_pe(x)
268
+ x = x * self.xscale
269
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
270
+ return self.dropout(x), self.dropout(pos_emb)
271
+
272
+ def position_encoding(self,
273
+ offset: Union[int, torch.Tensor],
274
+ size: int) -> torch.Tensor:
275
+ """ For getting encoding in a streaming fashion
276
+
277
+ Attention!!!!!
278
+ we apply dropout only once at the whole utterance level in a
279
+ non-streaming way, but will call this function several times with
280
+ increasing input size in a streaming scenario, so the dropout will
281
+ be applied several times.
282
+
283
+ Args:
284
+ offset (int or torch.tensor): start offset
285
+ size (int): required size of position encoding
286
+
287
+ Returns:
288
+ torch.Tensor: Corresponding encoding
289
+ """
290
+ # How to subscript a Union type:
291
+ # https://github.com/pytorch/pytorch/issues/69434
292
+ if isinstance(offset, int):
293
+ pos_emb = self.pe[
294
+ :,
295
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
296
+ ]
297
+ elif isinstance(offset, torch.Tensor):
298
+ pos_emb = self.pe[
299
+ :,
300
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
301
+ ]
302
+ return pos_emb
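
A quick shape check for the relative positional encoding above (a sketch, assuming the cosyvoice package layout shown in this upload is importable):

import torch
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding

pos_enc = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(2, 50, 256)
x_scaled, pos_emb = pos_enc(x)
print(x_scaled.shape)  # torch.Size([2, 50, 256])
print(pos_emb.shape)   # (1, 2 * 50 - 1, 256): covers positive and negative offsets
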
cosyvoice/transformer/encoder.py ADDED
@@ -0,0 +1,474 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ import torch.utils.checkpoint as ckpt
22
+
23
+ from cosyvoice.transformer.convolution import ConvolutionModule
24
+ from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
25
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
26
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from cosyvoice.utils.mask import make_pad_mask
34
+ from cosyvoice.utils.mask import add_optional_chunk_mask
35
+
36
+
37
+ class BaseEncoder(torch.nn.Module):
38
+
39
+ def __init__(
40
+ self,
41
+ input_size: int,
42
+ output_size: int = 256,
43
+ attention_heads: int = 4,
44
+ linear_units: int = 2048,
45
+ num_blocks: int = 6,
46
+ dropout_rate: float = 0.1,
47
+ positional_dropout_rate: float = 0.1,
48
+ attention_dropout_rate: float = 0.0,
49
+ input_layer: str = "conv2d",
50
+ pos_enc_layer_type: str = "abs_pos",
51
+ normalize_before: bool = True,
52
+ static_chunk_size: int = 0,
53
+ use_dynamic_chunk: bool = False,
54
+ global_cmvn: torch.nn.Module = None,
55
+ use_dynamic_left_chunk: bool = False,
56
+ gradient_checkpointing: bool = False,
57
+ ):
58
+ """
59
+ Args:
60
+ input_size (int): input dim
61
+ output_size (int): dimension of attention
62
+ attention_heads (int): the number of heads of multi head attention
63
+ linear_units (int): the number of hidden units of the position-wise feed
64
+ forward
65
+ num_blocks (int): the number of encoder blocks
66
+ dropout_rate (float): dropout rate
67
+ attention_dropout_rate (float): dropout rate in attention
68
+ positional_dropout_rate (float): dropout rate after adding
69
+ positional encoding
70
+ input_layer (str): input layer type.
71
+ optional [linear, conv2d, conv2d6, conv2d8]
72
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
73
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
74
+ normalize_before (bool):
75
+ True: use layer_norm before each sub-block of a layer.
76
+ False: use layer_norm after each sub-block of a layer.
77
+ static_chunk_size (int): chunk size for static chunk training and
78
+ decoding
79
+ use_dynamic_chunk (bool): whether to use dynamic chunk size for
80
+ training or not. You can only use a fixed chunk (chunk_size > 0)
81
+ or a dynamic chunk size (use_dynamic_chunk = True)
82
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
83
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
84
+ dynamic chunk training
85
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
86
+ gradient_checkpointing: rerunning a forward-pass segment for each
87
+ checkpointed segment during backward.
88
+ """
89
+ super().__init__()
90
+ self._output_size = output_size
91
+
92
+ self.global_cmvn = global_cmvn
93
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
94
+ input_size,
95
+ output_size,
96
+ dropout_rate,
97
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
98
+ positional_dropout_rate),
99
+ )
100
+
101
+ self.normalize_before = normalize_before
102
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
103
+ self.static_chunk_size = static_chunk_size
104
+ self.use_dynamic_chunk = use_dynamic_chunk
105
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
106
+ self.gradient_checkpointing = gradient_checkpointing
107
+
108
+ def output_size(self) -> int:
109
+ return self._output_size
110
+
111
+ def forward(
112
+ self,
113
+ xs: torch.Tensor,
114
+ xs_lens: torch.Tensor,
115
+ decoding_chunk_size: int = 0,
116
+ num_decoding_left_chunks: int = -1,
117
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
118
+ """Embed positions in tensor.
119
+
120
+ Args:
121
+ xs: padded input tensor (B, T, D)
122
+ xs_lens: input length (B)
123
+ decoding_chunk_size: decoding chunk size for dynamic chunk
124
+ 0: default for training, use random dynamic chunk.
125
+ <0: for decoding, use full chunk.
126
+ >0: for decoding, use fixed chunk size as set.
127
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
128
+ the chunk size is decoding_chunk_size.
129
+ >=0: use num_decoding_left_chunks
130
+ <0: use all left chunks
131
+ Returns:
132
+ encoder output tensor xs, and subsampled masks
133
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
134
+ masks: torch.Tensor batch padding mask after subsample
135
+ (B, 1, T' ~= T/subsample_rate)
136
+ NOTE(xcsong):
137
+ We pass the `__call__` method of the modules instead of `forward` to the
138
+ checkpointing API because `__call__` attaches all the hooks of the module.
139
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
140
+ """
141
+ T = xs.size(1)
142
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
143
+ if self.global_cmvn is not None:
144
+ xs = self.global_cmvn(xs)
145
+ xs, pos_emb, masks = self.embed(xs, masks)
146
+ mask_pad = masks # (B, 1, T/subsample_rate)
147
+ chunk_masks = add_optional_chunk_mask(xs, masks,
148
+ self.use_dynamic_chunk,
149
+ self.use_dynamic_left_chunk,
150
+ decoding_chunk_size,
151
+ self.static_chunk_size,
152
+ num_decoding_left_chunks)
153
+ if self.gradient_checkpointing and self.training:
154
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
155
+ mask_pad)
156
+ else:
157
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
158
+ if self.normalize_before:
159
+ xs = self.after_norm(xs)
160
+ # Here we assume the mask is not changed in encoder layers, so just
161
+ # return the masks before encoder layers, and the masks will be used
162
+ # for cross attention with decoder later
163
+ return xs, masks
164
+
165
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
166
+ pos_emb: torch.Tensor,
167
+ mask_pad: torch.Tensor) -> torch.Tensor:
168
+ for layer in self.encoders:
169
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
170
+ return xs
171
+
172
+ @torch.jit.unused
173
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
174
+ chunk_masks: torch.Tensor,
175
+ pos_emb: torch.Tensor,
176
+ mask_pad: torch.Tensor) -> torch.Tensor:
177
+ for layer in self.encoders:
178
+ xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
179
+ chunk_masks, pos_emb,
180
+ mask_pad)
181
+ return xs
182
+
183
+ @torch.jit.export
184
+ def forward_chunk(
185
+ self,
186
+ xs: torch.Tensor,
187
+ offset: int,
188
+ required_cache_size: int,
189
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
190
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
191
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
192
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
193
+ """ Forward just one chunk
194
+
195
+ Args:
196
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
197
+ where `time == (chunk_size - 1) * subsample_rate + \
198
+ subsample.right_context + 1`
199
+ offset (int): current offset in encoder output time stamp
200
+ required_cache_size (int): cache size required for next chunk
201
+ computation
202
+ >=0: actual cache size
203
+ <0: means all history cache is required
204
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
205
+ transformer/conformer attention, with shape
206
+ (elayers, head, cache_t1, d_k * 2), where
207
+ `head * d_k == hidden-dim` and
208
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
209
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
210
+ (elayers, b=1, hidden-dim, cache_t2), where
211
+ `cache_t2 == cnn.lorder - 1`
212
+
213
+ Returns:
214
+ torch.Tensor: output of current input xs,
215
+ with shape (b=1, chunk_size, hidden-dim).
216
+ torch.Tensor: new attention cache required for next chunk, with
217
+ dynamic shape (elayers, head, ?, d_k * 2)
218
+ depending on required_cache_size.
219
+ torch.Tensor: new conformer cnn cache required for next chunk, with
220
+ same shape as the original cnn_cache.
221
+
222
+ """
223
+ assert xs.size(0) == 1
224
+ # tmp_masks is just for interface compatibility
225
+ tmp_masks = torch.ones(1,
226
+ xs.size(1),
227
+ device=xs.device,
228
+ dtype=torch.bool)
229
+ tmp_masks = tmp_masks.unsqueeze(1)
230
+ if self.global_cmvn is not None:
231
+ xs = self.global_cmvn(xs)
232
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
233
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
234
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
235
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
236
+ chunk_size = xs.size(1)
237
+ attention_key_size = cache_t1 + chunk_size
238
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
239
+ size=attention_key_size)
240
+ if required_cache_size < 0:
241
+ next_cache_start = 0
242
+ elif required_cache_size == 0:
243
+ next_cache_start = attention_key_size
244
+ else:
245
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
246
+ r_att_cache = []
247
+ r_cnn_cache = []
248
+ for i, layer in enumerate(self.encoders):
249
+ # NOTE(xcsong): Before layer.forward
250
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
251
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
252
+ xs, _, new_att_cache, new_cnn_cache = layer(
253
+ xs,
254
+ att_mask,
255
+ pos_emb,
256
+ att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
257
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
258
+ # NOTE(xcsong): After layer.forward
259
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
260
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
261
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
262
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
263
+ if self.normalize_before:
264
+ xs = self.after_norm(xs)
265
+
266
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
267
+ # ? may be larger than cache_t1, it depends on required_cache_size
268
+ r_att_cache = torch.cat(r_att_cache, dim=0)
269
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
270
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
271
+
272
+ return (xs, r_att_cache, r_cnn_cache)
273
+
274
+ @torch.jit.unused
275
+ def forward_chunk_by_chunk(
276
+ self,
277
+ xs: torch.Tensor,
278
+ decoding_chunk_size: int,
279
+ num_decoding_left_chunks: int = -1,
280
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
281
+ """ Forward input chunk by chunk with chunk_size like a streaming
282
+ fashion
283
+
284
+ Here we should pay special attention to computation cache in the
285
+ streaming style forward chunk by chunk. Three things should be taken
286
+ into account for computation in the current network:
287
+ 1. transformer/conformer encoder layers output cache
288
+ 2. convolution in conformer
289
+ 3. convolution in subsampling
290
+
291
+ However, we don't implement subsampling cache for:
292
+ 1. We can control subsampling module to output the right result by
293
+ overlapping input instead of caching left context, even though it
294
+ wastes some computation, but subsampling only takes a very
295
+ small fraction of computation in the whole model.
296
+ 2. Typically, there are several convolution layers with subsampling
297
+ in subsampling module, it is tricky and complicated to do cache
298
+ with different convolution layers with different subsampling
299
+ rate.
300
+ 3. Currently, nn.Sequential is used to stack all the convolution
301
+ layers in subsampling, we need to rewrite it to make it work
302
+ with cache, which is not preferred.
303
+ Args:
304
+ xs (torch.Tensor): (1, max_len, dim)
305
+ decoding_chunk_size (int): decoding chunk size
306
+ """
307
+ assert decoding_chunk_size > 0
308
+ # The model is trained by static or dynamic chunk
309
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
310
+ subsampling = self.embed.subsampling_rate
311
+ context = self.embed.right_context + 1 # Add current frame
312
+ stride = subsampling * decoding_chunk_size
313
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
314
+ num_frames = xs.size(1)
315
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
316
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
317
+ outputs = []
318
+ offset = 0
319
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
320
+
321
+ # Feed forward overlap input step by step
322
+ for cur in range(0, num_frames - context + 1, stride):
323
+ end = min(cur + decoding_window, num_frames)
324
+ chunk_xs = xs[:, cur:end, :]
325
+ (y, att_cache,
326
+ cnn_cache) = self.forward_chunk(chunk_xs, offset,
327
+ required_cache_size, att_cache,
328
+ cnn_cache)
329
+ outputs.append(y)
330
+ offset += y.size(1)
331
+ ys = torch.cat(outputs, 1)
332
+ masks = torch.ones((1, 1, ys.size(1)),
333
+ device=ys.device,
334
+ dtype=torch.bool)
335
+ return ys, masks
336
+
337
+
338
+ class TransformerEncoder(BaseEncoder):
339
+ """Transformer encoder module."""
340
+
341
+ def __init__(
342
+ self,
343
+ input_size: int,
344
+ output_size: int = 256,
345
+ attention_heads: int = 4,
346
+ linear_units: int = 2048,
347
+ num_blocks: int = 6,
348
+ dropout_rate: float = 0.1,
349
+ positional_dropout_rate: float = 0.1,
350
+ attention_dropout_rate: float = 0.0,
351
+ input_layer: str = "conv2d",
352
+ pos_enc_layer_type: str = "abs_pos",
353
+ normalize_before: bool = True,
354
+ static_chunk_size: int = 0,
355
+ use_dynamic_chunk: bool = False,
356
+ global_cmvn: torch.nn.Module = None,
357
+ use_dynamic_left_chunk: bool = False,
358
+ key_bias: bool = True,
359
+ selfattention_layer_type: str = "selfattn",
360
+ activation_type: str = "relu",
361
+ gradient_checkpointing: bool = False,
362
+ ):
363
+ """ Construct TransformerEncoder
364
+
365
+ See Encoder for the meaning of each parameter.
366
+ """
367
+ super().__init__(input_size, output_size, attention_heads,
368
+ linear_units, num_blocks, dropout_rate,
369
+ positional_dropout_rate, attention_dropout_rate,
370
+ input_layer, pos_enc_layer_type, normalize_before,
371
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
372
+ use_dynamic_left_chunk, gradient_checkpointing)
373
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
374
+ self.encoders = torch.nn.ModuleList([
375
+ TransformerEncoderLayer(
376
+ output_size,
377
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
378
+ output_size,
379
+ attention_dropout_rate,
380
+ key_bias),
381
+ PositionwiseFeedForward(output_size, linear_units,
382
+ dropout_rate, activation),
383
+ dropout_rate, normalize_before) for _ in range(num_blocks)
384
+ ])
385
+
386
+
387
+ class ConformerEncoder(BaseEncoder):
388
+ """Conformer encoder module."""
389
+
390
+ def __init__(
391
+ self,
392
+ input_size: int,
393
+ output_size: int = 256,
394
+ attention_heads: int = 4,
395
+ linear_units: int = 2048,
396
+ num_blocks: int = 6,
397
+ dropout_rate: float = 0.1,
398
+ positional_dropout_rate: float = 0.1,
399
+ attention_dropout_rate: float = 0.0,
400
+ input_layer: str = "conv2d",
401
+ pos_enc_layer_type: str = "rel_pos",
402
+ normalize_before: bool = True,
403
+ static_chunk_size: int = 0,
404
+ use_dynamic_chunk: bool = False,
405
+ global_cmvn: torch.nn.Module = None,
406
+ use_dynamic_left_chunk: bool = False,
407
+ positionwise_conv_kernel_size: int = 1,
408
+ macaron_style: bool = True,
409
+ selfattention_layer_type: str = "rel_selfattn",
410
+ activation_type: str = "swish",
411
+ use_cnn_module: bool = True,
412
+ cnn_module_kernel: int = 15,
413
+ causal: bool = False,
414
+ cnn_module_norm: str = "batch_norm",
415
+ key_bias: bool = True,
416
+ gradient_checkpointing: bool = False,
417
+ ):
418
+ """Construct ConformerEncoder
419
+
420
+ Args:
421
+ input_size to use_dynamic_chunk, see in BaseEncoder
422
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
423
+ conv1d layer.
424
+ macaron_style (bool): Whether to use macaron style for
425
+ positionwise layer.
426
+ selfattention_layer_type (str): Encoder attention layer type,
427
+ the parameter has no effect now, it's just for configuration
428
+ compatibility.
429
+ activation_type (str): Encoder activation function type.
430
+ use_cnn_module (bool): Whether to use convolution module.
431
+ cnn_module_kernel (int): Kernel size of convolution module.
432
+ causal (bool): whether to use causal convolution or not.
433
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
434
+ """
435
+ super().__init__(input_size, output_size, attention_heads,
436
+ linear_units, num_blocks, dropout_rate,
437
+ positional_dropout_rate, attention_dropout_rate,
438
+ input_layer, pos_enc_layer_type, normalize_before,
439
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
440
+ use_dynamic_left_chunk, gradient_checkpointing)
441
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
442
+
443
+ # self-attention module definition
444
+ encoder_selfattn_layer_args = (
445
+ attention_heads,
446
+ output_size,
447
+ attention_dropout_rate,
448
+ key_bias,
449
+ )
450
+ # feed-forward module definition
451
+ positionwise_layer_args = (
452
+ output_size,
453
+ linear_units,
454
+ dropout_rate,
455
+ activation,
456
+ )
457
+ # convolution module definition
458
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
459
+ cnn_module_norm, causal)
460
+
461
+ self.encoders = torch.nn.ModuleList([
462
+ ConformerEncoderLayer(
463
+ output_size,
464
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
465
+ *encoder_selfattn_layer_args),
466
+ PositionwiseFeedForward(*positionwise_layer_args),
467
+ PositionwiseFeedForward(
468
+ *positionwise_layer_args) if macaron_style else None,
469
+ ConvolutionModule(
470
+ *convolution_layer_args) if use_cnn_module else None,
471
+ dropout_rate,
472
+ normalize_before,
473
+ ) for _ in range(num_blocks)
474
+ ])
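
A minimal forward pass through the ConformerEncoder defined above (a sketch: it relies on the sibling subsampling, attention, and convolution modules included in this upload and assumes the constructor defaults otherwise):

import torch
from cosyvoice.transformer.encoder import ConformerEncoder

encoder = ConformerEncoder(input_size=80, output_size=256, attention_heads=4,
                           linear_units=1024, num_blocks=2)
xs = torch.randn(2, 100, 80)          # (batch, frames, mel bins)
xs_lens = torch.tensor([100, 80])
ys, masks = encoder(xs, xs_lens)
print(ys.shape, masks.shape)          # the conv2d front end subsamples time by roughly 4
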
cosyvoice/transformer/encoder_layer.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class TransformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+
27
+ Args:
28
+ size (int): Input dimension.
29
+ self_attn (torch.nn.Module): Self-attention module instance.
30
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
+ instance can be used as the argument.
32
+ feed_forward (torch.nn.Module): Feed-forward module instance.
33
+ `PositionwiseFeedForward`, instance can be used as the argument.
34
+ dropout_rate (float): Dropout rate.
35
+ normalize_before (bool):
36
+ True: use layer_norm before each sub-block.
37
+ False: to use layer_norm after each sub-block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size: int,
43
+ self_attn: torch.nn.Module,
44
+ feed_forward: torch.nn.Module,
45
+ dropout_rate: float,
46
+ normalize_before: bool = True,
47
+ ):
48
+ """Construct an EncoderLayer object."""
49
+ super().__init__()
50
+ self.self_attn = self_attn
51
+ self.feed_forward = feed_forward
52
+ self.norm1 = nn.LayerNorm(size, eps=1e-12)
53
+ self.norm2 = nn.LayerNorm(size, eps=1e-12)
54
+ self.dropout = nn.Dropout(dropout_rate)
55
+ self.size = size
56
+ self.normalize_before = normalize_before
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: torch.Tensor,
62
+ pos_emb: torch.Tensor,
63
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
+ """Compute encoded features.
68
+
69
+ Args:
70
+ x (torch.Tensor): (#batch, time, size)
71
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
+ (0, 0, 0) means fake mask.
73
+ pos_emb (torch.Tensor): just for interface compatibility
74
+ to ConformerEncoderLayer
75
+ mask_pad (torch.Tensor): does not used in transformer layer,
76
+ just for unified api with conformer.
77
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
+ (#batch=1, size, cache_t2), not used here, it's for interface
81
+ compatibility to ConformerEncoderLayer.
82
+ Returns:
83
+ torch.Tensor: Output tensor (#batch, time, size).
84
+ torch.Tensor: Mask tensor (#batch, time, time).
85
+ torch.Tensor: att_cache tensor,
86
+ (#batch=1, head, cache_t1 + time, d_k * 2).
87
+ torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
88
+
89
+ """
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm1(x)
93
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
94
+ x = residual + self.dropout(x_att)
95
+ if not self.normalize_before:
96
+ x = self.norm1(x)
97
+
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm2(x)
101
+ x = residual + self.dropout(self.feed_forward(x))
102
+ if not self.normalize_before:
103
+ x = self.norm2(x)
104
+
105
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
106
+ return x, mask, new_att_cache, fake_cnn_cache
107
+
108
+
109
+ class ConformerEncoderLayer(nn.Module):
110
+ """Encoder layer module.
111
+ Args:
112
+ size (int): Input dimension.
113
+ self_attn (torch.nn.Module): Self-attention module instance.
114
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
115
+ instance can be used as the argument.
116
+ feed_forward (torch.nn.Module): Feed-forward module instance.
117
+ `PositionwiseFeedForward` instance can be used as the argument.
118
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
119
+ instance.
120
+ `PositionwiseFeedForward` instance can be used as the argument.
121
+ conv_module (torch.nn.Module): Convolution module instance.
122
+ `ConvlutionModule` instance can be used as the argument.
123
+ dropout_rate (float): Dropout rate.
124
+ normalize_before (bool):
125
+ True: use layer_norm before each sub-block.
126
+ False: use layer_norm after each sub-block.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ size: int,
132
+ self_attn: torch.nn.Module,
133
+ feed_forward: Optional[nn.Module] = None,
134
+ feed_forward_macaron: Optional[nn.Module] = None,
135
+ conv_module: Optional[nn.Module] = None,
136
+ dropout_rate: float = 0.1,
137
+ normalize_before: bool = True,
138
+ ):
139
+ """Construct an EncoderLayer object."""
140
+ super().__init__()
141
+ self.self_attn = self_attn
142
+ self.feed_forward = feed_forward
143
+ self.feed_forward_macaron = feed_forward_macaron
144
+ self.conv_module = conv_module
145
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
146
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
147
+ if feed_forward_macaron is not None:
148
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
149
+ self.ff_scale = 0.5
150
+ else:
151
+ self.ff_scale = 1.0
152
+ if self.conv_module is not None:
153
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
154
+ self.norm_final = nn.LayerNorm(
155
+ size, eps=1e-12) # for the final output of the block
156
+ self.dropout = nn.Dropout(dropout_rate)
157
+ self.size = size
158
+ self.normalize_before = normalize_before
159
+
160
+ def forward(
161
+ self,
162
+ x: torch.Tensor,
163
+ mask: torch.Tensor,
164
+ pos_emb: torch.Tensor,
165
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
166
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
167
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
168
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
+ """Compute encoded features.
170
+
171
+ Args:
172
+ x (torch.Tensor): (#batch, time, size)
173
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
174
+ (0, 0, 0) means fake mask.
175
+ pos_emb (torch.Tensor): positional encoding, must not be None
176
+ for ConformerEncoderLayer.
177
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
178
+ (#batch, 1,time), (0, 0, 0) means fake mask.
179
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
180
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
181
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
182
+ (#batch=1, size, cache_t2)
183
+ Returns:
184
+ torch.Tensor: Output tensor (#batch, time, size).
185
+ torch.Tensor: Mask tensor (#batch, time, time).
186
+ torch.Tensor: att_cache tensor,
187
+ (#batch=1, head, cache_t1 + time, d_k * 2).
188
+ torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
189
+ """
190
+
191
+ # whether to use macaron style
192
+ if self.feed_forward_macaron is not None:
193
+ residual = x
194
+ if self.normalize_before:
195
+ x = self.norm_ff_macaron(x)
196
+ x = residual + self.ff_scale * self.dropout(
197
+ self.feed_forward_macaron(x))
198
+ if not self.normalize_before:
199
+ x = self.norm_ff_macaron(x)
200
+
201
+ # multi-headed self-attention module
202
+ residual = x
203
+ if self.normalize_before:
204
+ x = self.norm_mha(x)
205
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
206
+ att_cache)
207
+ x = residual + self.dropout(x_att)
208
+ if not self.normalize_before:
209
+ x = self.norm_mha(x)
210
+
211
+ # convolution module
212
+ # Fake new cnn cache here, and then change it in conv_module
213
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
214
+ if self.conv_module is not None:
215
+ residual = x
216
+ if self.normalize_before:
217
+ x = self.norm_conv(x)
218
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
219
+ x = residual + self.dropout(x)
220
+
221
+ if not self.normalize_before:
222
+ x = self.norm_conv(x)
223
+
224
+ # feed forward module
225
+ residual = x
226
+ if self.normalize_before:
227
+ x = self.norm_ff(x)
228
+
229
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
230
+ if not self.normalize_before:
231
+ x = self.norm_ff(x)
232
+
233
+ if self.conv_module is not None:
234
+ x = self.norm_final(x)
235
+
236
+ return x, mask, new_att_cache, new_cnn_cache
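
A smoke test for TransformerEncoderLayer (a sketch: it assumes MultiHeadedAttention from cosyvoice.transformer.attention accepts the same (heads, size, dropout) constructor arguments that TransformerEncoder above passes to it):

import torch
from cosyvoice.transformer.attention import MultiHeadedAttention
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward

size, heads = 256, 4
layer = TransformerEncoderLayer(size,
                                MultiHeadedAttention(heads, size, 0.0),
                                PositionwiseFeedForward(size, 1024, 0.0),
                                dropout_rate=0.0)
x = torch.randn(2, 10, size)
mask = torch.ones(2, 10, 10, dtype=torch.bool)
pos_emb = torch.zeros(1, 10, size)  # ignored by plain (non-relative) self-attention
y, mask, _, _ = layer(x, mask, pos_emb)
print(y.shape)  # torch.Size([2, 10, 256])
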
cosyvoice/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Label smoothing module."""
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ class LabelSmoothingLoss(nn.Module):
22
+ """Label-smoothing loss.
23
+
24
+ In a standard CE loss, the label's data distribution is:
25
+ [0,1,2] ->
26
+ [
27
+ [1.0, 0.0, 0.0],
28
+ [0.0, 1.0, 0.0],
29
+ [0.0, 0.0, 1.0],
30
+ ]
31
+
32
+ In the label-smoothed version of the CE loss, some probability
33
+ is taken from the true label prob (1.0) and is divided
34
+ among other labels.
35
+
36
+ e.g.
37
+ smoothing=0.1
38
+ [0,1,2] ->
39
+ [
40
+ [0.9, 0.05, 0.05],
41
+ [0.05, 0.9, 0.05],
42
+ [0.05, 0.05, 0.9],
43
+ ]
44
+
45
+ Args:
46
+ size (int): the number of classes
47
+ padding_idx (int): padding class id which will be ignored for loss
48
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
49
+ normalize_length (bool):
50
+ normalize loss by sequence length if True
51
+ normalize loss by batch size if False
52
+ """
53
+
54
+ def __init__(self,
55
+ size: int,
56
+ padding_idx: int,
57
+ smoothing: float,
58
+ normalize_length: bool = False):
59
+ """Construct a LabelSmoothingLoss object."""
60
+ super(LabelSmoothingLoss, self).__init__()
61
+ self.criterion = nn.KLDivLoss(reduction="none")
62
+ self.padding_idx = padding_idx
63
+ self.confidence = 1.0 - smoothing
64
+ self.smoothing = smoothing
65
+ self.size = size
66
+ self.normalize_length = normalize_length
67
+
68
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
69
+ """Compute loss between x and target.
70
+
71
+ The model output and data label tensors are flattened to
72
+ (batch*seqlen, class) shape and a mask is applied to the
73
+ padding part, which should not contribute to the loss.
74
+
75
+ Args:
76
+ x (torch.Tensor): prediction (batch, seqlen, class)
77
+ target (torch.Tensor):
78
+ target signal masked with self.padding_id (batch, seqlen)
79
+ Returns:
80
+ loss (torch.Tensor) : The KL loss, scalar float value
81
+ """
82
+ assert x.size(2) == self.size
83
+ batch_size = x.size(0)
84
+ x = x.view(-1, self.size)
85
+ target = target.view(-1)
86
+ # use zeros_like instead of torch.no_grad() for true_dist,
87
+ # since no_grad() can not be exported by JIT
88
+ true_dist = torch.zeros_like(x)
89
+ true_dist.fill_(self.smoothing / (self.size - 1))
90
+ ignore = target == self.padding_idx # (B,)
91
+ total = len(target) - ignore.sum().item()
92
+ target = target.masked_fill(ignore, 0) # avoid -1 index
93
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
94
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
95
+ denom = total if self.normalize_length else batch_size
96
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
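
A worked example for LabelSmoothingLoss with the 3-class, smoothing=0.1 setup from the docstring above; positions marked with the padding id (-1 here) are excluded from the loss:

import torch
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=3, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 4, 3)          # (batch, seqlen, class)
target = torch.tensor([[0, 1, 2, -1],  # -1 marks padded positions
                       [2, 2, -1, -1]])
loss = criterion(logits, target)
print(loss.item())  # scalar KL loss, normalized by batch size by default
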
cosyvoice/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+ FeedForward are appied on each position of the sequence.
24
+ The output dim is same with the input dim.
25
+
26
+ Args:
27
+ idim (int): Input dimension.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
57
+
58
+ class MoEFFNLayer(torch.nn.Module):
59
+ """
60
+ Mixture of expert with Positionwise feed forward layer
61
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
+ The output dim is the same as the input dim.
63
+
64
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
+ Args:
67
+ n_expert: number of expert.
68
+ n_expert_per_token: The actual number of experts used for each frame
69
+ idim (int): Input dimension.
70
+ hidden_units (int): The number of hidden units.
71
+ dropout_rate (float): Dropout rate.
72
+ activation (torch.nn.Module): Activation function
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ n_expert: int,
78
+ n_expert_per_token: int,
79
+ idim: int,
80
+ hidden_units: int,
81
+ dropout_rate: float,
82
+ activation: torch.nn.Module = torch.nn.ReLU(),
83
+ ):
84
+ super(MoEFFNLayer, self).__init__()
85
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
+ self.experts = torch.nn.ModuleList(
87
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88
+ activation) for _ in range(n_expert))
89
+ self.n_expert_per_token = n_expert_per_token
90
+
91
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
92
+ """Forward function.
93
+ Args:
94
+ xs: input tensor (B, L, D)
95
+ Returns:
96
+ output tensor, (B, L, D)
97
+
98
+ """
99
+ B, L, D = xs.size(
100
+ ) # batch size, sequence length, embedding dimension (idim)
101
+ xs = xs.view(-1, D) # (B*L, D)
102
+ router = self.gate(xs) # (B*L, n_expert)
103
+ logits, indices = torch.topk(
104
+ router, self.n_expert_per_token
105
+ ) # logits, indices: (B*L, n_expert_per_token)
106
+ weights = torch.nn.functional.softmax(
107
+ logits, dim=1,
108
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109
+ output = torch.zeros_like(xs) # (B*L, D)
110
+ for i, expert in enumerate(self.experts):
111
+ mask = indices == i
112
+ batch_idx, ith_expert = torch.where(mask)
113
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
+ xs[batch_idx])
115
+ return output.view(B, L, D)
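
A shape check for the MoE feed-forward variant above (a sketch): 4 experts with 2 routed per frame, and the output keeps the input dimension, so it can stand in where PositionwiseFeedForward is used:

import torch
from cosyvoice.transformer.positionwise_feed_forward import MoEFFNLayer

moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=64,
                  hidden_units=256, dropout_rate=0.0)
xs = torch.randn(3, 10, 64)
ys = moe(xs)
print(ys.shape)  # torch.Size([3, 10, 64])
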
cosyvoice/transformer/subsampling.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(self, offset: Union[int, torch.Tensor],
31
+ size: int) -> torch.Tensor:
32
+ return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
+ class EmbedinigNoSubsampling(BaseSubsampling):
36
+ """Embedding input without subsampling
37
+ """
38
+
39
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
40
+ pos_enc_class: torch.nn.Module):
41
+ super().__init__()
42
+ self.embed = torch.nn.Embedding(idim, odim)
43
+ self.pos_enc = pos_enc_class
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ x_mask: torch.Tensor,
49
+ offset: Union[int, torch.Tensor] = 0
50
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
+ """Input x.
52
+
53
+ Args:
54
+ x (torch.Tensor): Input tensor (#batch, time, idim).
55
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
56
+
57
+ Returns:
58
+ torch.Tensor: linear input tensor (#batch, time', odim),
59
+ where time' = time .
60
+ torch.Tensor: linear input mask (#batch, 1, time'),
61
+ where time' = time .
62
+
63
+ """
64
+ x = self.embed(x)
65
+ x, pos_emb = self.pos_enc(x, offset)
66
+ return x, pos_emb, x_mask
67
+
68
+
69
+ class LinearNoSubsampling(BaseSubsampling):
70
+ """Linear transform the input without subsampling
71
+
72
+ Args:
73
+ idim (int): Input dimension.
74
+ odim (int): Output dimension.
75
+ dropout_rate (float): Dropout rate.
76
+
77
+ """
78
+
79
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
80
+ pos_enc_class: torch.nn.Module):
81
+ """Construct an linear object."""
82
+ super().__init__()
83
+ self.out = torch.nn.Sequential(
84
+ torch.nn.Linear(idim, odim),
85
+ torch.nn.LayerNorm(odim, eps=1e-5),
86
+ torch.nn.Dropout(dropout_rate),
87
+ )
88
+ self.pos_enc = pos_enc_class
89
+ self.right_context = 0
90
+ self.subsampling_rate = 1
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ x_mask: torch.Tensor,
96
+ offset: Union[int, torch.Tensor] = 0
97
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ """Input x.
99
+
100
+ Args:
101
+ x (torch.Tensor): Input tensor (#batch, time, idim).
102
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
103
+
104
+ Returns:
105
+ torch.Tensor: linear input tensor (#batch, time', odim),
106
+ where time' = time .
107
+ torch.Tensor: linear input mask (#batch, 1, time'),
108
+ where time' = time .
109
+
110
+ """
111
+ x = self.out(x)
112
+ x, pos_emb = self.pos_enc(x, offset)
113
+ return x, pos_emb, x_mask
114
+
115
+
116
+ class Conv1dSubsampling2(BaseSubsampling):
117
+ """Convolutional 1D subsampling (to 1/2 length).
118
+ It is designed for Whisper, ref:
119
+ https://github.com/openai/whisper/blob/main/whisper/model.py
120
+
121
+ Args:
122
+ idim (int): Input dimension.
123
+ odim (int): Output dimension.
124
+ dropout_rate (float): Dropout rate.
125
+
126
+ """
127
+
128
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
129
+ pos_enc_class: torch.nn.Module):
130
+ """Construct an Conv1dSubsampling2 object."""
131
+ super().__init__()
132
+ self.conv = torch.nn.Sequential(
133
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
134
+ torch.nn.GELU(),
135
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
136
+ torch.nn.GELU(),
137
+ )
138
+ self.pos_enc = pos_enc_class
139
+ # The right context for every conv layer is computed by:
140
+ # (kernel_size - 1) * frame_rate_of_this_layer
141
+ self.subsampling_rate = 2
142
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
143
+ self.right_context = 4
144
+
145
+ def forward(
146
+ self,
147
+ x: torch.Tensor,
148
+ x_mask: torch.Tensor,
149
+ offset: Union[int, torch.Tensor] = 0
150
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151
+ """Subsample x.
152
+
153
+ Args:
154
+ x (torch.Tensor): Input tensor (#batch, time, idim).
155
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
156
+
157
+ Returns:
158
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
159
+ where time' = time // 2.
160
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
161
+ where time' = time // 2.
162
+ torch.Tensor: positional encoding
163
+
164
+ """
165
+ time = x.size(1)
166
+ x = x.transpose(1, 2) # (b, f, t)
167
+ x = self.conv(x)
168
+ x = x.transpose(1, 2) # (b, t, f)
169
+ x, pos_emb = self.pos_enc(x, offset)
170
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
+
172
+
173
+ class Conv2dSubsampling4(BaseSubsampling):
174
+ """Convolutional 2D subsampling (to 1/4 length).
175
+
176
+ Args:
177
+ idim (int): Input dimension.
178
+ odim (int): Output dimension.
179
+ dropout_rate (float): Dropout rate.
180
+
181
+ """
182
+
183
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
184
+ pos_enc_class: torch.nn.Module):
185
+ """Construct an Conv2dSubsampling4 object."""
186
+ super().__init__()
187
+ self.conv = torch.nn.Sequential(
188
+ torch.nn.Conv2d(1, odim, 3, 2),
189
+ torch.nn.ReLU(),
190
+ torch.nn.Conv2d(odim, odim, 3, 2),
191
+ torch.nn.ReLU(),
192
+ )
193
+ self.out = torch.nn.Sequential(
194
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
195
+ self.pos_enc = pos_enc_class
196
+ # The right context for every conv layer is computed by:
197
+ # (kernel_size - 1) * frame_rate_of_this_layer
198
+ self.subsampling_rate = 4
199
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
200
+ self.right_context = 6
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ x_mask: torch.Tensor,
206
+ offset: Union[int, torch.Tensor] = 0
207
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208
+ """Subsample x.
209
+
210
+ Args:
211
+ x (torch.Tensor): Input tensor (#batch, time, idim).
212
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
213
+
214
+ Returns:
215
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
216
+ where time' = time // 4.
217
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
218
+ where time' = time // 4.
219
+ torch.Tensor: positional encoding
220
+
221
+ """
222
+ x = x.unsqueeze(1) # (b, c=1, t, f)
223
+ x = self.conv(x)
224
+ b, c, t, f = x.size()
225
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
226
+ x, pos_emb = self.pos_enc(x, offset)
227
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
+
229
+
230
+ class Conv2dSubsampling6(BaseSubsampling):
231
+ """Convolutional 2D subsampling (to 1/6 length).
232
+ Args:
233
+ idim (int): Input dimension.
234
+ odim (int): Output dimension.
235
+ dropout_rate (float): Dropout rate.
236
+ pos_enc (torch.nn.Module): Custom position encoding layer.
237
+ """
238
+
239
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
240
+ pos_enc_class: torch.nn.Module):
241
+ """Construct an Conv2dSubsampling6 object."""
242
+ super().__init__()
243
+ self.conv = torch.nn.Sequential(
244
+ torch.nn.Conv2d(1, odim, 3, 2),
245
+ torch.nn.ReLU(),
246
+ torch.nn.Conv2d(odim, odim, 5, 3),
247
+ torch.nn.ReLU(),
248
+ )
249
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
250
+ odim)
251
+ self.pos_enc = pos_enc_class
252
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
253
+ self.subsampling_rate = 6
254
+ self.right_context = 10
255
+
256
+ def forward(
257
+ self,
258
+ x: torch.Tensor,
259
+ x_mask: torch.Tensor,
260
+ offset: Union[int, torch.Tensor] = 0
261
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
+ """Subsample x.
263
+ Args:
264
+ x (torch.Tensor): Input tensor (#batch, time, idim).
265
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
266
+
267
+ Returns:
268
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
269
+ where time' = time // 6.
270
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
271
+ where time' = time // 6.
272
+ torch.Tensor: positional encoding
273
+ """
274
+ x = x.unsqueeze(1) # (b, c, t, f)
275
+ x = self.conv(x)
276
+ b, c, t, f = x.size()
277
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
278
+ x, pos_emb = self.pos_enc(x, offset)
279
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
+
281
+
282
+ class Conv2dSubsampling8(BaseSubsampling):
283
+ """Convolutional 2D subsampling (to 1/8 length).
284
+
285
+ Args:
286
+ idim (int): Input dimension.
287
+ odim (int): Output dimension.
288
+ dropout_rate (float): Dropout rate.
289
+
290
+ """
291
+
292
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
293
+ pos_enc_class: torch.nn.Module):
294
+ """Construct an Conv2dSubsampling8 object."""
295
+ super().__init__()
296
+ self.conv = torch.nn.Sequential(
297
+ torch.nn.Conv2d(1, odim, 3, 2),
298
+ torch.nn.ReLU(),
299
+ torch.nn.Conv2d(odim, odim, 3, 2),
300
+ torch.nn.ReLU(),
301
+ torch.nn.Conv2d(odim, odim, 3, 2),
302
+ torch.nn.ReLU(),
303
+ )
304
+ self.linear = torch.nn.Linear(
305
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
306
+ self.pos_enc = pos_enc_class
307
+ self.subsampling_rate = 8
308
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
309
+ self.right_context = 14
310
+
311
+ def forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ x_mask: torch.Tensor,
315
+ offset: Union[int, torch.Tensor] = 0
316
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
317
+ """Subsample x.
318
+
319
+ Args:
320
+ x (torch.Tensor): Input tensor (#batch, time, idim).
321
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
322
+
323
+ Returns:
324
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
325
+ where time' = time // 8.
326
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
327
+ where time' = time // 8.
328
+ torch.Tensor: positional encoding
329
+ """
330
+ x = x.unsqueeze(1) # (b, c, t, f)
331
+ x = self.conv(x)
332
+ b, c, t, f = x.size()
333
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
334
+ x, pos_emb = self.pos_enc(x, offset)
335
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
+
337
+
338
+ class LegacyLinearNoSubsampling(BaseSubsampling):
339
+ """Linear transform the input without subsampling
340
+
341
+ Args:
342
+ idim (int): Input dimension.
343
+ odim (int): Output dimension.
344
+ dropout_rate (float): Dropout rate.
345
+
346
+ """
347
+
348
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
349
+ pos_enc_class: torch.nn.Module):
350
+ """Construct an linear object."""
351
+ super().__init__()
352
+ self.out = torch.nn.Sequential(
353
+ torch.nn.Linear(idim, odim),
354
+ torch.nn.LayerNorm(odim, eps=1e-5),
355
+ torch.nn.Dropout(dropout_rate),
356
+ torch.nn.ReLU(),
357
+ )
358
+ self.pos_enc = pos_enc_class
359
+ self.right_context = 0
360
+ self.subsampling_rate = 1
361
+
362
+ def forward(
363
+ self,
364
+ x: torch.Tensor,
365
+ x_mask: torch.Tensor,
366
+ offset: Union[int, torch.Tensor] = 0
367
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368
+ """Input x.
369
+
370
+ Args:
371
+ x (torch.Tensor): Input tensor (#batch, time, idim).
372
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
373
+
374
+ Returns:
375
+ torch.Tensor: linear input tensor (#batch, time', odim),
376
+ where time' = time .
377
+ torch.Tensor: linear input mask (#batch, 1, time'),
378
+ where time' = time .
379
+
380
+ """
381
+ x = self.out(x)
382
+ x, pos_emb = self.pos_enc(x, offset)
383
+ return x, pos_emb, x_mask
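A quick shape check for `Conv2dSubsampling4` from the file above. The `DummyPosEnc` module is only an illustration of the interface the subsampling layers expect; the real positional-encoding classes live in `cosyvoice/transformer/embedding.py`:

```python
import torch
from cosyvoice.transformer.subsampling import Conv2dSubsampling4

class DummyPosEnc(torch.nn.Module):
    # minimal stand-in: returns the input and a zero positional embedding
    def forward(self, x, offset=0):
        return x, torch.zeros(1, x.size(1), x.size(2))

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1, pos_enc_class=DummyPosEnc())
x = torch.randn(2, 100, 80)                     # (batch, time, feature)
mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
y, pos_emb, y_mask = sub(x, mask)
print(y.shape, y_mask.shape)                    # (2, 24, 256) and (2, 1, 24), i.e. roughly time // 4
```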
cosyvoice/transformer/upsample_encoder.py ADDED
@@ -0,0 +1,321 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ from cosyvoice.transformer.convolution import ConvolutionModule
25
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
26
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from cosyvoice.utils.mask import make_pad_mask
34
+ from cosyvoice.utils.mask import add_optional_chunk_mask
35
+
36
+
37
+ class Upsample1D(nn.Module):
38
+ """A 1D upsampling layer with an optional convolution.
39
+
40
+ Parameters:
41
+ channels (`int`):
42
+ number of channels in the inputs and outputs.
43
+ use_conv (`bool`, default `False`):
44
+ option to use a convolution.
45
+ use_conv_transpose (`bool`, default `False`):
46
+ option to use a convolution transpose.
47
+ out_channels (`int`, optional):
48
+ number of output channels. Defaults to `channels`.
49
+ """
50
+
51
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
52
+ super().__init__()
53
+ self.channels = channels
54
+ self.out_channels = out_channels
55
+ self.stride = stride
56
+ # In this mode, first repeat-interpolate, then conv with stride=1
57
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
58
+
59
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
60
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
61
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
62
+ outputs = self.conv(outputs)
63
+ return outputs, input_lengths * self.stride
64
+
65
+
66
+ class PreLookaheadLayer(nn.Module):
67
+ def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
68
+ super().__init__()
69
+ self.in_channels = in_channels
70
+ self.channels = channels
71
+ self.pre_lookahead_len = pre_lookahead_len
72
+ self.conv1 = nn.Conv1d(
73
+ in_channels, channels,
74
+ kernel_size=pre_lookahead_len + 1,
75
+ stride=1, padding=0,
76
+ )
77
+ self.conv2 = nn.Conv1d(
78
+ channels, in_channels,
79
+ kernel_size=3, stride=1, padding=0,
80
+ )
81
+
82
+ def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
83
+ """
84
+ inputs: (batch_size, seq_len, channels)
85
+ """
86
+ outputs = inputs.transpose(1, 2).contiguous()
87
+ context = context.transpose(1, 2).contiguous()
88
+ # look ahead
89
+ if context.size(2) == 0:
90
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
91
+ else:
92
+ assert self.training is False, 'you have passed context, make sure that you are running in inference mode'
93
+ assert context.size(2) == self.pre_lookahead_len
94
+ outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
95
+ outputs = F.leaky_relu(self.conv1(outputs))
96
+ # outputs
97
+ outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
98
+ outputs = self.conv2(outputs)
99
+ outputs = outputs.transpose(1, 2).contiguous()
100
+
101
+ # residual connection
102
+ outputs = outputs + inputs
103
+ return outputs
104
+
105
+
106
+ class UpsampleConformerEncoder(torch.nn.Module):
107
+
108
+ def __init__(
109
+ self,
110
+ input_size: int,
111
+ output_size: int = 256,
112
+ attention_heads: int = 4,
113
+ linear_units: int = 2048,
114
+ num_blocks: int = 6,
115
+ dropout_rate: float = 0.1,
116
+ positional_dropout_rate: float = 0.1,
117
+ attention_dropout_rate: float = 0.0,
118
+ input_layer: str = "conv2d",
119
+ pos_enc_layer_type: str = "rel_pos",
120
+ normalize_before: bool = True,
121
+ static_chunk_size: int = 0,
122
+ use_dynamic_chunk: bool = False,
123
+ global_cmvn: torch.nn.Module = None,
124
+ use_dynamic_left_chunk: bool = False,
125
+ positionwise_conv_kernel_size: int = 1,
126
+ macaron_style: bool = True,
127
+ selfattention_layer_type: str = "rel_selfattn",
128
+ activation_type: str = "swish",
129
+ use_cnn_module: bool = True,
130
+ cnn_module_kernel: int = 15,
131
+ causal: bool = False,
132
+ cnn_module_norm: str = "batch_norm",
133
+ key_bias: bool = True,
134
+ gradient_checkpointing: bool = False,
135
+ ):
136
+ """
137
+ Args:
138
+ input_size (int): input dim
139
+ output_size (int): dimension of attention
140
+ attention_heads (int): the number of heads of multi head attention
141
+ linear_units (int): the hidden units number of position-wise feed
142
+ forward
143
+ num_blocks (int): the number of decoder blocks
144
+ dropout_rate (float): dropout rate
145
+ attention_dropout_rate (float): dropout rate in attention
146
+ positional_dropout_rate (float): dropout rate after adding
147
+ positional encoding
148
+ input_layer (str): input layer type.
149
+ optional [linear, conv2d, conv2d6, conv2d8]
150
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
151
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
152
+ normalize_before (bool):
153
+ True: use layer_norm before each sub-block of a layer.
154
+ False: use layer_norm after each sub-block of a layer.
155
+ static_chunk_size (int): chunk size for static chunk training and
156
+ decoding
157
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
158
+ training or not. You can only use fixed chunk(chunk_size > 0)
159
+ or dynamic chunk size(use_dynamic_chunk = True)
160
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
161
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
162
+ dynamic chunk training
163
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
164
+ gradient_checkpointing: rerunning a forward-pass segment for each
165
+ checkpointed segment during backward.
166
+ """
167
+ super().__init__()
168
+ self._output_size = output_size
169
+
170
+ self.global_cmvn = global_cmvn
171
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
172
+ input_size,
173
+ output_size,
174
+ dropout_rate,
175
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
176
+ positional_dropout_rate),
177
+ )
178
+
179
+ self.normalize_before = normalize_before
180
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
181
+ self.static_chunk_size = static_chunk_size
182
+ self.use_dynamic_chunk = use_dynamic_chunk
183
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
184
+ self.gradient_checkpointing = gradient_checkpointing
185
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
186
+ # self-attention module definition
187
+ encoder_selfattn_layer_args = (
188
+ attention_heads,
189
+ output_size,
190
+ attention_dropout_rate,
191
+ key_bias,
192
+ )
193
+ # feed-forward module definition
194
+ positionwise_layer_args = (
195
+ output_size,
196
+ linear_units,
197
+ dropout_rate,
198
+ activation,
199
+ )
200
+ # convolution module definition
201
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
202
+ cnn_module_norm, causal)
203
+ self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
204
+ self.encoders = torch.nn.ModuleList([
205
+ ConformerEncoderLayer(
206
+ output_size,
207
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
208
+ *encoder_selfattn_layer_args),
209
+ PositionwiseFeedForward(*positionwise_layer_args),
210
+ PositionwiseFeedForward(
211
+ *positionwise_layer_args) if macaron_style else None,
212
+ ConvolutionModule(
213
+ *convolution_layer_args) if use_cnn_module else None,
214
+ dropout_rate,
215
+ normalize_before,
216
+ ) for _ in range(num_blocks)
217
+ ])
218
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
219
+ self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
220
+ input_size,
221
+ output_size,
222
+ dropout_rate,
223
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
224
+ positional_dropout_rate),
225
+ )
226
+ self.up_encoders = torch.nn.ModuleList([
227
+ ConformerEncoderLayer(
228
+ output_size,
229
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
230
+ *encoder_selfattn_layer_args),
231
+ PositionwiseFeedForward(*positionwise_layer_args),
232
+ PositionwiseFeedForward(
233
+ *positionwise_layer_args) if macaron_style else None,
234
+ ConvolutionModule(
235
+ *convolution_layer_args) if use_cnn_module else None,
236
+ dropout_rate,
237
+ normalize_before,
238
+ ) for _ in range(4)
239
+ ])
240
+
241
+ def output_size(self) -> int:
242
+ return self._output_size
243
+
244
+ def forward(
245
+ self,
246
+ xs: torch.Tensor,
247
+ xs_lens: torch.Tensor,
248
+ context: torch.Tensor = torch.zeros(0, 0, 0),
249
+ decoding_chunk_size: int = 0,
250
+ num_decoding_left_chunks: int = -1,
251
+ streaming: bool = False,
252
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
253
+ """Embed positions in tensor.
254
+
255
+ Args:
256
+ xs: padded input tensor (B, T, D)
257
+ xs_lens: input length (B)
258
+ decoding_chunk_size: decoding chunk size for dynamic chunk
259
+ 0: default for training, use random dynamic chunk.
260
+ <0: for decoding, use full chunk.
261
+ >0: for decoding, use fixed chunk size as set.
262
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
263
+ the chunk size is decoding_chunk_size.
264
+ >=0: use num_decoding_left_chunks
265
+ <0: use all left chunks
266
+ Returns:
267
+ encoder output tensor xs, and subsampled masks
268
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
269
+ masks: torch.Tensor batch padding mask after subsample
270
+ (B, 1, T' ~= T/subsample_rate)
271
+ NOTE(xcsong):
272
+ We pass the `__call__` method of the modules instead of `forward` to the
273
+ checkpointing API because `__call__` attaches all the hooks of the module.
274
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
275
+ """
276
+ T = xs.size(1)
277
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
278
+ if self.global_cmvn is not None:
279
+ xs = self.global_cmvn(xs)
280
+ xs, pos_emb, masks = self.embed(xs, masks)
281
+ if context.size(1) != 0:
282
+ assert self.training is False, 'you have passed context, make sure that you are running in inference mode'
283
+ context_masks = torch.ones(1, 1, context.size(1)).to(masks)
284
+ context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
285
+ mask_pad = masks # (B, 1, T/subsample_rate)
286
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
287
+ # lookahead + conformer encoder
288
+ xs = self.pre_lookahead_layer(xs, context=context)
289
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
290
+
291
+ # upsample + conformer encoder
292
+ xs = xs.transpose(1, 2).contiguous()
293
+ xs, xs_lens = self.up_layer(xs, xs_lens)
294
+ xs = xs.transpose(1, 2).contiguous()
295
+ T = xs.size(1)
296
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
297
+ xs, pos_emb, masks = self.up_embed(xs, masks)
298
+ mask_pad = masks # (B, 1, T/subsample_rate)
299
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
300
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
301
+
302
+ if self.normalize_before:
303
+ xs = self.after_norm(xs)
304
+ # Here we assume the mask is not changed in encoder layers, so just
305
+ # return the masks before encoder layers, and the masks will be used
306
+ # for cross attention with decoder later
307
+ return xs, masks
308
+
309
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
310
+ pos_emb: torch.Tensor,
311
+ mask_pad: torch.Tensor) -> torch.Tensor:
312
+ for layer in self.encoders:
313
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
314
+ return xs
315
+
316
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
317
+ pos_emb: torch.Tensor,
318
+ mask_pad: torch.Tensor) -> torch.Tensor:
319
+ for layer in self.up_encoders:
320
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
321
+ return xs
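As a small illustration of the length bookkeeping in `Upsample1D` above (the channel count and input length are arbitrary; importing this module assumes the full CosyVoice dependency stack is installed):

```python
import torch
from cosyvoice.transformer.upsample_encoder import Upsample1D

up = Upsample1D(channels=512, out_channels=512, stride=2)
x = torch.randn(1, 512, 50)          # (batch, channels, time)
lengths = torch.tensor([50])
y, y_lengths = up(x, lengths)
print(y.shape, y_lengths)            # torch.Size([1, 512, 100]) tensor([100])
```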
cosyvoice/utils/__init__.py ADDED
File without changes
cosyvoice/utils/class_utils.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice.transformer.activation import Swish
18
+ from cosyvoice.transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from cosyvoice.transformer.embedding import (PositionalEncoding,
27
+ RelPositionalEncoding,
28
+ WhisperPositionalEncoding,
29
+ LearnablePositionalEncoding,
30
+ NoPositionalEncoding)
31
+ from cosyvoice.transformer.attention import (MultiHeadedAttention,
32
+ RelPositionMultiHeadedAttention)
33
+ from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34
+ from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35
+ from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
36
+ from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
37
+ from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
38
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
39
+
40
+
41
+ COSYVOICE_ACTIVATION_CLASSES = {
42
+ "hardtanh": torch.nn.Hardtanh,
43
+ "tanh": torch.nn.Tanh,
44
+ "relu": torch.nn.ReLU,
45
+ "selu": torch.nn.SELU,
46
+ "swish": getattr(torch.nn, "SiLU", Swish),
47
+ "gelu": torch.nn.GELU,
48
+ }
49
+
50
+ COSYVOICE_SUBSAMPLE_CLASSES = {
51
+ "linear": LinearNoSubsampling,
52
+ "linear_legacy": LegacyLinearNoSubsampling,
53
+ "embed": EmbedinigNoSubsampling,
54
+ "conv1d2": Conv1dSubsampling2,
55
+ "conv2d": Conv2dSubsampling4,
56
+ "conv2d6": Conv2dSubsampling6,
57
+ "conv2d8": Conv2dSubsampling8,
58
+ 'paraformer_dummy': torch.nn.Identity
59
+ }
60
+
61
+ COSYVOICE_EMB_CLASSES = {
62
+ "embed": PositionalEncoding,
63
+ "abs_pos": PositionalEncoding,
64
+ "rel_pos": RelPositionalEncoding,
65
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
66
+ "no_pos": NoPositionalEncoding,
67
+ "abs_pos_whisper": WhisperPositionalEncoding,
68
+ "embed_learnable_pe": LearnablePositionalEncoding,
69
+ }
70
+
71
+ COSYVOICE_ATTENTION_CLASSES = {
72
+ "selfattn": MultiHeadedAttention,
73
+ "rel_selfattn": RelPositionMultiHeadedAttention,
74
+ }
75
+
76
+
77
+ def get_model_type(configs):
78
+ # NOTE CosyVoice2Model inherits CosyVoiceModel
79
+ if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
80
+ return CosyVoiceModel
81
+ if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
82
+ return CosyVoice2Model
83
+ if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
84
+ return CosyVoice3Model
85
+ raise TypeError('No valid model type found!')
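The dictionaries above act as small registries indexed by the strings that appear in the YAML configs. A sketch of how they are typically resolved (the hyperparameter values here are illustrative):

```python
from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
)

act = COSYVOICE_ACTIVATION_CLASSES["swish"]()          # SiLU on recent torch, Swish fallback otherwise
pos_enc = COSYVOICE_EMB_CLASSES["rel_pos"](256, 0.1)   # (d_model, dropout_rate)
embed = COSYVOICE_SUBSAMPLE_CLASSES["linear"](80, 256, 0.1, pos_enc)
print(type(act).__name__, type(embed).__name__)
```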
cosyvoice/utils/common.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Unility functions for Transformer."""
18
+
19
+ import queue
20
+ import random
21
+ from typing import List
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ IGNORE_ID = -1
27
+
28
+ instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
29
+ "You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
30
+ "You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
31
+ "You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
32
+ "You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
33
+ "You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
34
+ "You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
35
+ "You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
36
+ "You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
37
+ "You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
38
+ "You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
39
+ "You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
40
+ "You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
41
+ "You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
42
+ "You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
43
+ "You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
44
+ "You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
45
+ "You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
46
+ "You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
47
+ "You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
48
+ "You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
49
+ "You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
50
+ "You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
51
+ "You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
52
+ "You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
53
+ "You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
54
+
55
+
56
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
57
+ """Perform padding for the list of tensors.
58
+
59
+ Args:
60
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
61
+ pad_value (float): Value for padding.
62
+
63
+ Returns:
64
+ Tensor: Padded tensor (B, Tmax, `*`).
65
+
66
+ Examples:
67
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
68
+ >>> x
69
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
70
+ >>> pad_list(x, 0)
71
+ tensor([[1., 1., 1., 1.],
72
+ [1., 1., 0., 0.],
73
+ [1., 0., 0., 0.]])
74
+
75
+ """
76
+ max_len = max([len(item) for item in xs])
77
+ batchs = len(xs)
78
+ ndim = xs[0].ndim
79
+ if ndim == 1:
80
+ pad_res = torch.zeros(batchs,
81
+ max_len,
82
+ dtype=xs[0].dtype,
83
+ device=xs[0].device)
84
+ elif ndim == 2:
85
+ pad_res = torch.zeros(batchs,
86
+ max_len,
87
+ xs[0].shape[1],
88
+ dtype=xs[0].dtype,
89
+ device=xs[0].device)
90
+ elif ndim == 3:
91
+ pad_res = torch.zeros(batchs,
92
+ max_len,
93
+ xs[0].shape[1],
94
+ xs[0].shape[2],
95
+ dtype=xs[0].dtype,
96
+ device=xs[0].device)
97
+ else:
98
+ raise ValueError(f"Unsupported ndim: {ndim}")
99
+ pad_res.fill_(pad_value)
100
+ for i in range(batchs):
101
+ pad_res[i, :len(xs[i])] = xs[i]
102
+ return pad_res
103
+
104
+
105
+ def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
106
+ ignore_label: int) -> torch.Tensor:
107
+ """Calculate accuracy.
108
+
109
+ Args:
110
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
111
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
112
+ ignore_label (int): Ignore label id.
113
+
114
+ Returns:
115
+ torch.Tensor: Accuracy value (0.0 - 1.0).
116
+
117
+ """
118
+ pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
119
+ pad_outputs.size(1)).argmax(2)
120
+ mask = pad_targets != ignore_label
121
+ numerator = torch.sum(
122
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
123
+ denominator = torch.sum(mask)
124
+ return (numerator / denominator).detach()
125
+
126
+
127
+ def get_padding(kernel_size, dilation=1):
128
+ return int((kernel_size * dilation - dilation) / 2)
129
+
130
+
131
+ def init_weights(m, mean=0.0, std=0.01):
132
+ classname = m.__class__.__name__
133
+ if classname.find("Conv") != -1:
134
+ m.weight.data.normal_(mean, std)
135
+
136
+
137
+ # Repetition Aware Sampling in VALL-E 2
138
+ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
139
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
140
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
141
+ if rep_num >= win_size * tau_r:
142
+ top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
143
+ return top_ids
144
+
145
+
146
+ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
147
+ prob, indices = [], []
148
+ cum_prob = 0.0
149
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
150
+ for i in range(len(sorted_idx)):
151
+ # sampling both top-p and numbers.
152
+ if cum_prob < top_p and len(prob) < top_k:
153
+ cum_prob += sorted_value[i]
154
+ prob.append(sorted_value[i])
155
+ indices.append(sorted_idx[i])
156
+ else:
157
+ break
158
+ prob = torch.tensor(prob).to(weighted_scores)
159
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
160
+ top_ids = indices[prob.multinomial(1, replacement=True)].item()
161
+ return top_ids
162
+
163
+
164
+ def random_sampling(weighted_scores, decoded_tokens, sampling):
165
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
166
+ return top_ids
167
+
168
+
169
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
170
+ device = fade_in_mel.device
171
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
172
+ mel_overlap_len = int(window.shape[0] / 2)
173
+ if fade_in_mel.device == torch.device('cpu'):
174
+ fade_in_mel = fade_in_mel.clone()
175
+ fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
176
+ fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
177
+ return fade_in_mel.to(device)
178
+
179
+
180
+ def set_all_random_seed(seed):
181
+ random.seed(seed)
182
+ np.random.seed(seed)
183
+ torch.manual_seed(seed)
184
+ torch.cuda.manual_seed_all(seed)
185
+
186
+
187
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
188
+ assert mask.dtype == torch.bool
189
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
190
+ mask = mask.to(dtype)
191
+ # attention mask bias
192
+ # NOTE(Mddct): torch.finfo jit issues
193
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
194
+ mask = (1.0 - mask) * -1.0e+10
195
+ return mask
196
+
197
+
198
+ class TrtContextWrapper:
199
+ def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
200
+ self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
201
+ self.trt_engine = trt_engine
202
+ for _ in range(trt_concurrent):
203
+ trt_context = trt_engine.create_execution_context()
204
+ trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
205
+ assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent {}'.format(trt_concurrent)
206
+ self.trt_context_pool.put([trt_context, trt_stream])
207
+ assert self.trt_context_pool.empty() is False, 'no available estimator context'
208
+
209
+ def acquire_estimator(self):
210
+ return self.trt_context_pool.get(), self.trt_engine
211
+
212
+ def release_estimator(self, context, stream):
213
+ self.trt_context_pool.put([context, stream])
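A toy run of the repetition-aware sampling defined above; the scores and the token history are fabricated purely for illustration:

```python
import torch
from cosyvoice.utils.common import ras_sampling, set_all_random_seed

set_all_random_seed(0)
weighted_scores = torch.randn(100)   # unnormalized scores over 100 speech tokens
decoded_tokens = [7] * 10            # degenerate history that keeps repeating token 7
token = ras_sampling(weighted_scores, decoded_tokens, sampling=None,
                     top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
print(token)                         # falls back to random sampling whenever token 7 is drawn again
```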
cosyvoice/utils/executor.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from contextlib import nullcontext
18
+ import os
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
24
+
25
+
26
+ class Executor:
27
+
28
+ def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
29
+ self.gan = gan
30
+ self.ref_model = ref_model
31
+ self.dpo_loss = dpo_loss
32
+ self.step = 0
33
+ self.epoch = 0
34
+ self.rank = int(os.environ.get('RANK', 0))
35
+ self.device = torch.device('cuda:{}'.format(self.rank))
36
+
37
+ def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
38
+ ''' Train one epoch
39
+ '''
40
+
41
+ lr = optimizer.param_groups[0]['lr']
42
+ logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
43
+ logging.info('using accumulate grad, new batch size is {} times'
44
+ ' larger than before'.format(info_dict['accum_grad']))
45
+ # A context manager to be used in conjunction with an instance of
46
+ # torch.nn.parallel.DistributedDataParallel to be able to train
47
+ # with uneven inputs across participating processes.
48
+ model.train()
49
+ if self.ref_model is not None:
50
+ self.ref_model.eval()
51
+ model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
52
+ with model_context():
53
+ for batch_idx, batch_dict in enumerate(train_data_loader):
54
+ info_dict["tag"] = "TRAIN"
55
+ info_dict["step"] = self.step
56
+ info_dict["epoch"] = self.epoch
57
+ info_dict["batch_idx"] = batch_idx
58
+ if cosyvoice_join(group_join, info_dict):
59
+ break
60
+
61
+ # Disable gradient synchronizations across DDP processes.
62
+ # Within this context, gradients will be accumulated on module
63
+ # variables, which will later be synchronized.
64
+ if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
65
+ context = model.no_sync
66
+ # Used for single gpu training and DDP gradient synchronization
67
+ # processes.
68
+ else:
69
+ context = nullcontext
70
+
71
+ with context():
72
+ info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
73
+ info_dict = batch_backward(model, scaler, info_dict)
74
+
75
+ info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
76
+ log_per_step(writer, info_dict)
77
+ # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
78
+ if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
79
+ (batch_idx + 1) % info_dict["accum_grad"] == 0:
80
+ dist.barrier()
81
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
82
+ model.train()
83
+ if (batch_idx + 1) % info_dict["accum_grad"] == 0:
84
+ self.step += 1
85
+ dist.barrier()
86
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
87
+
88
+ def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
89
+ writer, info_dict, scaler, group_join):
90
+ ''' Train one epoch
91
+ '''
92
+
93
+ lr = optimizer.param_groups[0]['lr']
94
+ logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
95
+ logging.info('using accumulate grad, new batch size is {} times'
96
+ ' larger than before'.format(info_dict['accum_grad']))
97
+ # A context manager to be used in conjunction with an instance of
98
+ # torch.nn.parallel.DistributedDataParallel to be able to train
99
+ # with uneven inputs across participating processes.
100
+ model.train()
101
+ model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
102
+ with model_context():
103
+ for batch_idx, batch_dict in enumerate(train_data_loader):
104
+ info_dict["tag"] = "TRAIN"
105
+ info_dict["step"] = self.step
106
+ info_dict["epoch"] = self.epoch
107
+ info_dict["batch_idx"] = batch_idx
108
+ if cosyvoice_join(group_join, info_dict):
109
+ break
110
+
111
+ # Disable gradient synchronizations across DDP processes.
112
+ # Within this context, gradients will be accumulated on module
113
+ # variables, which will later be synchronized.
114
+ if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
115
+ context = model.no_sync
116
+ # Used for single gpu training and DDP gradient synchronization
117
+ # processes.
118
+ else:
119
+ context = nullcontext
120
+
121
+ with context():
122
+ batch_dict['turn'] = 'discriminator'
123
+ info_dict = batch_forward(model, batch_dict, scaler, info_dict)
124
+ info_dict = batch_backward(model, scaler, info_dict)
125
+ info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
126
+ optimizer.zero_grad()
127
+ log_per_step(writer, info_dict)
128
+ with context():
129
+ batch_dict['turn'] = 'generator'
130
+ info_dict = batch_forward(model, batch_dict, scaler, info_dict)
131
+ info_dict = batch_backward(model, scaler, info_dict)
132
+ info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
133
+ optimizer_d.zero_grad()
134
+ log_per_step(writer, info_dict)
135
+ # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
136
+ if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
137
+ (batch_idx + 1) % info_dict["accum_grad"] == 0:
138
+ dist.barrier()
139
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
140
+ model.train()
141
+ if (batch_idx + 1) % info_dict["accum_grad"] == 0:
142
+ self.step += 1
143
+ dist.barrier()
144
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
145
+
146
+ @torch.inference_mode()
147
+ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
148
+ ''' Cross validation
149
+ '''
150
+ logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
151
+ model.eval()
152
+ total_num_utts, total_loss_dict = 0, {} # avoid division by 0
153
+ for batch_idx, batch_dict in enumerate(cv_data_loader):
154
+ info_dict["tag"] = "CV"
155
+ info_dict["step"] = self.step
156
+ info_dict["epoch"] = self.epoch
157
+ info_dict["batch_idx"] = batch_idx
158
+
159
+ num_utts = len(batch_dict["utts"])
160
+ total_num_utts += num_utts
161
+
162
+ if self.gan is True:
163
+ batch_dict['turn'] = 'generator'
164
+ info_dict = batch_forward(model, batch_dict, None, info_dict)
165
+
166
+ for k, v in info_dict['loss_dict'].items():
167
+ if k not in total_loss_dict:
168
+ total_loss_dict[k] = []
169
+ total_loss_dict[k].append(v.mean().item() * num_utts)
170
+ log_per_step(None, info_dict)
171
+ for k, v in total_loss_dict.items():
172
+ total_loss_dict[k] = sum(v) / total_num_utts
173
+ info_dict['loss_dict'] = total_loss_dict
174
+ log_per_save(writer, info_dict)
175
+ model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
176
+ save_model(model, model_name, info_dict)
cosyvoice/utils/file_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
3
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import json
19
+ import torch
20
+ import torchaudio
21
+ import logging
22
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
23
+ logging.basicConfig(level=logging.DEBUG,
24
+ format='%(asctime)s %(levelname)s %(message)s')
25
+
26
+
27
+ def read_lists(list_file):
28
+ lists = []
29
+ with open(list_file, 'r', encoding='utf8') as fin:
30
+ for line in fin:
31
+ lists.append(line.strip())
32
+ return lists
33
+
34
+
35
+ def read_json_lists(list_file):
36
+ lists = read_lists(list_file)
37
+ results = {}
38
+ for fn in lists:
39
+ with open(fn, 'r', encoding='utf8') as fin:
40
+ results.update(json.load(fin))
41
+ return results
42
+
43
+
44
+ def load_wav(wav, target_sr, min_sr=16000):
45
+ speech, sample_rate = torchaudio.load(wav, backend='soundfile')
46
+ speech = speech.mean(dim=0, keepdim=True)
47
+ if sample_rate != target_sr:
48
+ assert sample_rate >= min_sr, 'wav sample rate {} must be at least {}'.format(sample_rate, min_sr)
49
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
50
+ return speech
51
+
52
+
53
+ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
54
+ import tensorrt as trt
55
+ logging.info("Converting onnx to trt...")
56
+ network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
57
+ logger = trt.Logger(trt.Logger.INFO)
58
+ builder = trt.Builder(logger)
59
+ network = builder.create_network(network_flags)
60
+ parser = trt.OnnxParser(network, logger)
61
+ config = builder.create_builder_config()
62
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
63
+ if fp16:
64
+ config.set_flag(trt.BuilderFlag.FP16)
65
+ profile = builder.create_optimization_profile()
66
+ # load onnx model
67
+ with open(onnx_model, "rb") as f:
68
+ if not parser.parse(f.read()):
69
+ for error in range(parser.num_errors):
70
+ print(parser.get_error(error))
71
+ raise ValueError('failed to parse {}'.format(onnx_model))
72
+ # set input shapes
73
+ for i in range(len(trt_kwargs['input_names'])):
74
+ profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
75
+ tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
76
+ # set input and output data type
77
+ for i in range(network.num_inputs):
78
+ input_tensor = network.get_input(i)
79
+ input_tensor.dtype = tensor_dtype
80
+ for i in range(network.num_outputs):
81
+ output_tensor = network.get_output(i)
82
+ output_tensor.dtype = tensor_dtype
83
+ config.add_optimization_profile(profile)
84
+ engine_bytes = builder.build_serialized_network(network, config)
85
+ # save trt engine
86
+ with open(trt_model, "wb") as f:
87
+ f.write(engine_bytes)
88
+ logging.info("Succesfully convert onnx to trt...")
89
+
90
+
91
+ # NOTE do not support bistream inference as only speech token embedding/head is kept
92
+ def export_cosyvoice2_vllm(model, model_path, device):
93
+ if os.path.exists(model_path):
94
+ return
95
+
96
+ dtype = torch.bfloat16
97
+ # lm_head
98
+ use_bias = True if model.llm_decoder.bias is not None else False
99
+ model.llm.model.lm_head = model.llm_decoder
100
+ # embed_tokens
101
+ embed_tokens = model.llm.model.model.embed_tokens
102
+ model.llm.model.set_input_embeddings(model.speech_embedding)
103
+ model.llm.model.to(device)
104
+ model.llm.model.to(dtype)
105
+ tmp_vocab_size = model.llm.model.config.vocab_size
106
+ tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
107
+ del model.llm.model.generation_config.eos_token_id
108
+ del model.llm.model.config.bos_token_id
109
+ del model.llm.model.config.eos_token_id
110
+ model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
111
+ model.llm.model.config.tie_word_embeddings = False
112
+ model.llm.model.config.use_bias = use_bias
113
+ model.llm.model.save_pretrained(model_path)
114
+ if use_bias is True:
115
+ os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
116
+ model.llm.model.config.vocab_size = tmp_vocab_size
117
+ model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
118
+ model.llm.model.set_input_embeddings(embed_tokens)
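For example, loading a prompt wav at the 16 kHz rate expected by the frontend; the path below is a placeholder for any mono or stereo wav recorded at 16 kHz or higher:

```python
from cosyvoice.utils.file_utils import load_wav

speech = load_wav('zero_shot_prompt.wav', target_sr=16000)   # placeholder path
print(speech.shape)   # (1, num_samples) after channel averaging and optional resampling
```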
cosyvoice/utils/frontend_utils.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ import regex
17
+ chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
18
+
19
+
20
+ # whether the text contains Chinese characters
21
+ def contains_chinese(text):
22
+ return bool(chinese_char_pattern.search(text))
23
+
24
+
25
+ # replace special symbol
26
+ def replace_corner_mark(text):
27
+ text = text.replace('²', '平方')
28
+ text = text.replace('³', '立方')
29
+ return text
30
+
31
+
32
+ # remove meaningless symbol
33
+ def remove_bracket(text):
34
+ text = text.replace('(', '').replace(')', '')
35
+ text = text.replace('【', '').replace('】', '')
36
+ text = text.replace('`', '').replace('`', '')
37
+ text = text.replace("——", " ")
38
+ return text
39
+
40
+
41
+ # spell Arabic numerals
42
+ def spell_out_number(text: str, inflect_parser):
43
+ new_text = []
44
+ st = None
45
+ for i, c in enumerate(text):
46
+ if not c.isdigit():
47
+ if st is not None:
48
+ num_str = inflect_parser.number_to_words(text[st: i])
49
+ new_text.append(num_str)
50
+ st = None
51
+ new_text.append(c)
52
+ else:
53
+ if st is None:
54
+ st = i
55
+ if st is not None and st < len(text):
56
+ num_str = inflect_parser.number_to_words(text[st:])
57
+ new_text.append(num_str)
58
+ return ''.join(new_text)
59
+
60
+
61
+ # split paragrah logic:
62
+ # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
63
+ # 2. cal sentence len according to lang
64
+ # 3. split sentence according to puncatation
65
+ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
66
+ def calc_utt_length(_text: str):
67
+ if lang == "zh":
68
+ return len(_text)
69
+ else:
70
+ return len(tokenize(_text))
71
+
72
+ def should_merge(_text: str):
73
+ if lang == "zh":
74
+ return len(_text) < merge_len
75
+ else:
76
+ return len(tokenize(_text)) < merge_len
77
+
78
+ if lang == "zh":
79
+ pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
80
+ else:
81
+ pounc = ['.', '?', '!', ';', ':']
82
+ if comma_split:
83
+ pounc.extend([',', ','])
84
+
85
+ if text[-1] not in pounc:
86
+ if lang == "zh":
87
+ text += "。"
88
+ else:
89
+ text += "."
90
+
91
+ st = 0
92
+ utts = []
93
+ for i, c in enumerate(text):
94
+ if c in pounc:
95
+ if len(text[st: i]) > 0:
96
+ utts.append(text[st: i] + c)
97
+ if i + 1 < len(text) and text[i + 1] in ['"', '”']:
98
+ tmp = utts.pop(-1)
99
+ utts.append(tmp + text[i + 1])
100
+ st = i + 2
101
+ else:
102
+ st = i + 1
103
+
104
+ final_utts = []
105
+ cur_utt = ""
106
+ for utt in utts:
107
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
108
+ final_utts.append(cur_utt)
109
+ cur_utt = ""
110
+ cur_utt = cur_utt + utt
111
+ if len(cur_utt) > 0:
112
+ if should_merge(cur_utt) and len(final_utts) != 0:
113
+ final_utts[-1] = final_utts[-1] + cur_utt
114
+ else:
115
+ final_utts.append(cur_utt)
116
+
117
+ return final_utts
118
+
119
+
120
+ # remove blank between chinese character
121
+ def replace_blank(text: str):
122
+ out_str = []
123
+ for i, c in enumerate(text):
124
+ if c == " ":
125
+ if ((text[i + 1].isascii() and text[i + 1] != " ") and
126
+ (text[i - 1].isascii() and text[i - 1] != " ")):
127
+ out_str.append(c)
128
+ else:
129
+ out_str.append(c)
130
+ return "".join(out_str)
131
+
132
+
133
+ def is_only_punctuation(text):
134
+ # Regular expression: Match strings that consist only of punctuation marks or are empty.
135
+ punctuation_pattern = r'^[\p{P}\p{S}]*$'
136
+ return bool(regex.fullmatch(punctuation_pattern, text))
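A quick check of the paragraph splitting above; the thresholds are deliberately tiny so that the toy text splits, whereas the real frontend passes the model tokenizer and the default lengths:

```python
from cosyvoice.utils.frontend_utils import contains_chinese, split_paragraph

text = "今天天气很好。我们去公园散步。然后再去喝咖啡。"
pieces = split_paragraph(text, tokenize=list, lang="zh",
                         token_max_n=10, token_min_n=5, merge_len=2)
print(contains_chinese(text), pieces)   # True ['今天天气很好。', '我们去公园散步。', '然后再去喝咖啡。']
```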
cosyvoice/utils/losses.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Tuple
4
+
5
+
6
+ def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
7
+ loss = 0
8
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
9
+ m_DG = torch.median((dr - dg))
10
+ L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
11
+ loss += tau - F.relu(tau - L_rel)
12
+ return loss
13
+
14
+
15
+ def mel_loss(real_speech, generated_speech, mel_transforms):
16
+ loss = 0
17
+ for transform in mel_transforms:
18
+ mel_r = transform(real_speech)
19
+ mel_g = transform(generated_speech)
20
+ loss += F.l1_loss(mel_g, mel_r)
21
+ return loss
22
+
23
+
24
+ class DPOLoss(torch.nn.Module):
25
+ """
26
+ DPO Loss
27
+ """
28
+
29
+ def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
30
+ super().__init__()
31
+ self.beta = beta
32
+ self.label_smoothing = label_smoothing
33
+ self.ipo = ipo
34
+
35
+ def forward(
36
+ self,
37
+ policy_chosen_logps: torch.Tensor,
38
+ policy_rejected_logps: torch.Tensor,
39
+ reference_chosen_logps: torch.Tensor,
40
+ reference_rejected_logps: torch.Tensor,
41
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
42
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
43
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
44
+ logits = pi_logratios - ref_logratios
45
+ if self.ipo:
46
+ losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
47
+ else:
48
+ # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
49
+ losses = (
50
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
51
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
52
+ )
53
+ loss = losses.mean()
54
+ chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
55
+ rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
56
+
57
+ return loss, chosen_rewards, rejected_rewards
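A minimal numeric check of `DPOLoss`; the log-probabilities below are fabricated, not produced by any model:

```python
import torch
from cosyvoice.utils.losses import DPOLoss

dpo = DPOLoss(beta=0.1)
policy_chosen = torch.tensor([-5.0]); policy_rejected = torch.tensor([-9.0])
ref_chosen = torch.tensor([-6.0]);    ref_rejected = torch.tensor([-8.0])
loss, chosen_r, rejected_r = dpo(policy_chosen, policy_rejected, ref_chosen, ref_rejected)
print(loss.item(), chosen_r.item(), rejected_r.item())   # chosen reward 0.1, rejected reward -0.1
```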