
Commit 66cc9c4

Implement tts_iter
myshell-ai#88
1 parent 66de122 commit 66cc9c4

File tree

1 file changed: +211 -48 lines changed


melo/api.py

Lines changed: 211 additions & 48 deletions
@@ -27,22 +27,22 @@
 
 start_time = datetime.now()
 
+
 class TTS(nn.Module):
-    def __init__(self,
-                 language,
-                 device='auto',
-                 use_hf=True,
-                 config_path=None,
-                 ckpt_path=None):
+    def __init__(
+        self, language, device="auto", use_hf=True, config_path=None, ckpt_path=None
+    ):
         super().__init__()
-        if device == 'auto':
-            device = 'cpu'
-            if torch.cuda.is_available(): device = 'cuda'
-            if torch.backends.mps.is_available(): device = 'mps'
-        if 'cuda' in device:
+        if device == "auto":
+            device = "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            if torch.backends.mps.is_available():
+                device = "mps"
+        if "cuda" in device:
             assert torch.cuda.is_available()
 
-        # config_path =
+        # config_path =
         hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
 
         num_languages = hps.num_languages
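For context, a minimal instantiation sketch of the constructor reformatted above, assuming the package is installed so that melo.api is importable. Note the resolution order the new code makes explicit: "auto" starts at "cpu", upgrades to "cuda" if available, and then to "mps" if available, so a machine exposing both backends ends up on "mps".

from melo.api import TTS

# "auto" -> "cpu", then "cuda" if available, then "mps" if available;
# the MPS check runs last, so it wins when both backends are present.
model = TTS(language="EN", device="auto")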
@@ -64,16 +64,20 @@ def __init__(self,
         self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
         self.hps = hps
         self.device = device
-
+
         # load state_dict
-        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
-        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
-
-        language = language.split('_')[0]
-        self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
+        checkpoint_dict = load_or_download_model(
+            language, device, use_hf=use_hf, ckpt_path=ckpt_path
+        )
+        self.model.load_state_dict(checkpoint_dict["model"], strict=True)
+
+        language = language.split("_")[0]
+        self.language = (
+            "ZH_MIX_EN" if language == "ZH" else language
+        )  # we support a ZH_MIX_EN model
 
     @staticmethod
-    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+    def audio_numpy_concat(segment_data_list, sr, speed=1.0):
         audio_segments = []
         for segment_data in segment_data_list:
             audio_segments += segment_data.reshape(-1).tolist()
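The loader above still accepts local files; a hedged sketch of skipping the Hugging Face hub, where the paths are hypothetical placeholders and the offline fallback behavior of load_or_download_config/load_or_download_model is assumed from their signatures.

from melo.api import TTS

model = TTS(
    language="EN",
    use_hf=False,
    config_path="/path/to/config.json",   # hypothetical local config
    ckpt_path="/path/to/checkpoint.pth",  # hypothetical local checkpoint
)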
@@ -86,11 +90,24 @@ def split_sentences_into_pieces(text, language, quiet=False):
         texts = split_sentence(text, language_str=language)
         if not quiet:
             print(" > Text split to sentences.")
-            print('\n'.join(texts))
+            print("\n".join(texts))
             print(" > ===========================")
         return texts
 
-    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def tts_to_file(
+        self,
+        text,
+        speaker_id,
+        output_path=None,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
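A usage sketch for the reformatted tts_to_file signature, assuming the model setup from the previous example; the speaker table lives at model.hps.data.spk2id, and "EN-US" is an assumed key for the English model.

from melo.api import TTS

model = TTS(language="EN")
speaker_ids = model.hps.data.spk2id  # e.g. {"EN-US": 0, ...} (assumed)

# With output_path set, the audio is written to disk; with output_path=None
# the method returns the concatenated numpy waveform instead.
model.tts_to_file("Hello world.", speaker_ids["EN-US"], output_path="out.wav")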
@@ -104,10 +121,12 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
         else:
             tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -117,7 +136,8 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -128,26 +148,43 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                #
+                #
             audio_list.append(audio)
         torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )
 
         if output_path is None:
             return audio
         else:
             if format:
-                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
+                soundfile.write(
+                    output_path, audio, self.hps.data.sampling_rate, format=format
+                )
             else:
                 soundfile.write(output_path, audio, self.hps.data.sampling_rate)
 
-
-
-
-    def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def old_tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -161,10 +198,12 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
         else:
             tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -174,7 +213,8 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -185,26 +225,149 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                #
+                #
             audio_list.append(audio)
         torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )
 
         with io.BytesIO() as wav_buffer:
-            soundfile.write(wav_buffer, audio, self.hps.data.sampling_rate, format="WAV")
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
             wav_buffer.seek(0)
             wav_bytes = wav_buffer.read()
 
-
         wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
         end_time = datetime.now()
         elapsed_time = end_time - start_time
 
-        return jsonable_encoder({
-            "audioContent": wav_base64,
-            "time_taken": elapsed_time
-        })
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
+
+    def tts_iter(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        position=None,
+        quiet=False,
+    ):
+        """
+        https://github.com/myshell-ai/MeloTTS/pull/88/files
+        """
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language, quiet)
+
+        if pbar:
+            tx = pbar(texts)
+        else:
+            if position:
+                tx = tqdm(texts, position=position)
+            elif quiet:
+                tx = texts
+            else:
+                tx = tqdm(texts)
+        for t in tx:
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                audio = (
+                    self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        speakers,
+                        tones,
+                        lang_ids,
+                        bert,
+                        ja_bert,
+                        sdp_ratio=sdp_ratio,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
+                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+
+            audio_segments = []
+            audio_segments += audio.reshape(-1).tolist()
+            audio_segments += [0] * int(
+                (self.hps.data.sampling_rate * 0.05) / speed
+            )
+            audio_segments = np.array(audio_segments).astype(np.float32)
+
+            yield audio_segments
+
+        torch.cuda.empty_cache()
+
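The new tts_iter above is the point of this commit: a generator that yields one float32 numpy chunk per sentence, each padded with about 50 ms of trailing silence (scaled by speed), so callers can start playback before the whole text is synthesized. A minimal streaming sketch, reusing the assumed "EN-US" speaker key from the earlier examples:

import soundfile

from melo.api import TTS

model = TTS(language="EN")
speaker_id = model.hps.data.spk2id["EN-US"]  # assumed speaker key

text = "First sentence. Second sentence."
# Each yielded chunk is one sentence's audio at model.hps.data.sampling_rate.
for i, chunk in enumerate(model.tts_iter(text, speaker_id, quiet=True)):
    soundfile.write(f"chunk_{i}.wav", chunk, model.hps.data.sampling_rate)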
+    def tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
+        audio_list = []
+        for audio in self.tts_iter(
+            text,
+            speaker_id,
+            sdp_ratio,
+            noise_scale,
+            noise_scale_w,
+            speed,
+            pbar,
+            position,
+            quiet,
+        ):
+            audio_list.append(audio)
+
+        audio = np.concatenate(audio_list)
+
+        with io.BytesIO() as wav_buffer:
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
+            wav_buffer.seek(0)
+            wav_bytes = wav_buffer.read()
+
+        wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
+        end_time = datetime.now()
+        elapsed_time = end_time - start_time
 
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
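The rewritten tts_to_base64 above now simply drains tts_iter and wraps the concatenated audio as a base64-encoded WAV. A sketch of decoding its payload, again with the assumed "EN-US" speaker key. One caveat visible in the diff: elapsed_time is measured against the module-level start_time, so "time_taken" reports time since melo.api was imported, not since the call began.

import base64

from melo.api import TTS

model = TTS(language="EN")
speaker_id = model.hps.data.spk2id["EN-US"]  # assumed speaker key

result = model.tts_to_base64("Hello world.", speaker_id, quiet=True)
wav_bytes = base64.b64decode(result["audioContent"])  # back to raw WAV bytes
with open("decoded.wav", "wb") as f:
    f.write(wav_bytes)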
