 start_time = datetime.now()
 
 
+
 class TTS(nn.Module):
-    def __init__(self,
-                language,
-                device='auto',
-                use_hf=True,
-                config_path=None,
-                ckpt_path=None):
+    def __init__(
+        self, language, device="auto", use_hf=True, config_path=None, ckpt_path=None
+    ):
         super().__init__()
-        if device == 'auto':
-            device = 'cpu'
-            if torch.cuda.is_available(): device = 'cuda'
-            if torch.backends.mps.is_available(): device = 'mps'
-        if 'cuda' in device:
+        if device == "auto":
+            device = "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            if torch.backends.mps.is_available():
+                device = "mps"
+        if "cuda" in device:
             assert torch.cuda.is_available()
 
-        # config_path = 
+        # config_path =
         hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
 
         num_languages = hps.num_languages
@@ -64,16 +64,20 @@ def __init__(self,
         self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
         self.hps = hps
         self.device = device
-        
+
         # load state_dict
-        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
-        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
-        
-        language = language.split('_')[0]
-        self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
+        checkpoint_dict = load_or_download_model(
+            language, device, use_hf=use_hf, ckpt_path=ckpt_path
+        )
+        self.model.load_state_dict(checkpoint_dict["model"], strict=True)
+
+        language = language.split("_")[0]
+        self.language = (
+            "ZH_MIX_EN" if language == "ZH" else language
+        )  # we support a ZH_MIX_EN model
 
     @staticmethod
-    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+    def audio_numpy_concat(segment_data_list, sr, speed=1.0):
         audio_segments = []
         for segment_data in segment_data_list:
             audio_segments += segment_data.reshape(-1).tolist()
@@ -86,11 +90,24 @@ def split_sentences_into_pieces(text, language, quiet=False):
         texts = split_sentence(text, language_str=language)
         if not quiet:
             print(" > Text split to sentences.")
-            print('\n'.join(texts))
+            print("\n".join(texts))
             print(" > ===========================")
         return texts
 
-    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def tts_to_file(
+        self,
+        text,
+        speaker_id,
+        output_path=None,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -104,10 +121,12 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
             else:
                 tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -117,7 +136,8 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -128,26 +148,43 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                # 
+                #
                 audio_list.append(audio)
             torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )
 
         if output_path is None:
             return audio
         else:
             if format:
-                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
+                soundfile.write(
+                    output_path, audio, self.hps.data.sampling_rate, format=format
+                )
             else:
                 soundfile.write(output_path, audio, self.hps.data.sampling_rate)
 
-
-
-
-    def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def old_tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -161,10 +198,12 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
             else:
                 tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -174,7 +213,8 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -185,26 +225,149 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                # 
+                #
                 audio_list.append(audio)
             torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )
 
         with io.BytesIO() as wav_buffer:
-            soundfile.write(wav_buffer, audio, self.hps.data.sampling_rate, format="WAV")
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
             wav_buffer.seek(0)
             wav_bytes = wav_buffer.read()
 
-        
         wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
         end_time = datetime.now()
         elapsed_time = end_time - start_time
 
-        return jsonable_encoder({
-            "audioContent": wav_base64,
-            "time_taken": elapsed_time
-        })
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
+
+    def tts_iter(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        position=None,
+        quiet=False,
+    ):
+        """
+        https://github.com/myshell-ai/MeloTTS/pull/88/files
+        """
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language, quiet)
+
+        if pbar:
+            tx = pbar(texts)
+        else:
+            if position:
+                tx = tqdm(texts, position=position)
+            elif quiet:
+                tx = texts
+            else:
+                tx = tqdm(texts)
+        for t in tx:
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                audio = (
+                    self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        speakers,
+                        tones,
+                        lang_ids,
+                        bert,
+                        ja_bert,
+                        sdp_ratio=sdp_ratio,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
+                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+
+                audio_segments = []
+                audio_segments += audio.reshape(-1).tolist()
+                audio_segments += [0] * int(
+                    (self.hps.data.sampling_rate * 0.05) / speed
+                )
+                audio_segments = np.array(audio_segments).astype(np.float32)
+
+                yield audio_segments
+
+            torch.cuda.empty_cache()
+
+    def tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
+        audio_list = []
+        for audio in self.tts_iter(
+            text,
+            speaker_id,
+            sdp_ratio,
+            noise_scale,
+            noise_scale_w,
+            speed,
+            pbar,
+            position,
+            quiet,
+        ):
+            audio_list.append(audio)
+
+        audio = np.concatenate(audio_list)
+
+        with io.BytesIO() as wav_buffer:
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
+            wav_buffer.seek(0)
+            wav_bytes = wav_buffer.read()
+
+        wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
+        end_time = datetime.now()
+        elapsed_time = end_time - start_time
 
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
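
For anyone trying out the new streaming path, a minimal usage sketch follows (not part of the diff). It assumes the file changed here is importable as a module named `melo_api` and that speaker ids are looked up through `hps.data.spk2id` as in upstream MeloTTS; the module name and the "EN-US" speaker key are illustrative assumptions.

import numpy as np
import soundfile

from melo_api import TTS  # hypothetical module name for the file in this diff

model = TTS(language="EN", device="auto")
speaker_id = model.hps.data.spk2id["EN-US"]  # assumed speaker key; check your checkpoint

chunks = []
for chunk in model.tts_iter("Hello world. This is a streaming test.", speaker_id):
    # Each chunk is a float32 array for one sentence, padded with roughly
    # 50 ms of silence (int(sampling_rate * 0.05 / speed) samples).
    chunks.append(chunk)

soundfile.write("streamed.wav", np.concatenate(chunks), model.hps.data.sampling_rate)

Because tts_iter yields each sentence as soon as it is synthesized, it can also feed a chunked HTTP response instead of buffering the whole utterance, which is exactly what the rewritten tts_to_base64 does when it concatenates the chunks.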
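The rewritten tts_to_base64 returns the result of jsonable_encoder (presumably FastAPI's fastapi.encoders.jsonable_encoder), so callers receive a plain JSON-safe dict. A short sketch, reusing model and speaker_id from the previous example, of turning that payload back into a WAV file:

import base64

payload = model.tts_to_base64("Hello world.", speaker_id)
with open("from_base64.wav", "wb") as f:
    f.write(base64.b64decode(payload["audioContent"]))
print(payload["time_taken"])

Note that start_time is a module-level value captured at import (visible at the top of the diff), so time_taken reports elapsed time since the module was imported rather than per-request latency.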