diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 8e9d6bd382..43bcefd4da 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -209,6 +209,8 @@ def __init__(self, config: Coqpit):
         self.decoder_checkpoint = self.args.decoder_checkpoint  # TODO: check if this is even needed
         self.models_dir = config.model_dir
         self.gpt_batch_size = self.args.gpt_batch_size
+        self._stream_text_holder = []
+        self._stream_generator = None
 
         self.tokenizer = VoiceBpeTokenizer()
         self.gpt = None
@@ -632,64 +634,140 @@ def inference_stream(
         length_scale = 1.0 / max(speed, 0.05)
         gpt_cond_latent = gpt_cond_latent.to(self.device)
         speaker_embedding = speaker_embedding.to(self.device)
-        if enable_text_splitting:
-            text = split_sentence(text, language, self.tokenizer.char_limits[language])
-        else:
-            text = [text]
+        text_streaming = (text is None)
 
-        for sent in text:
-            sent = sent.strip().lower()
-            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+        while True:
+            if text_streaming:
+                yield None
+                if len(self._stream_text_holder) == 0:
+                    return
+                text, enable_text_splitting = self._stream_text_holder
 
-            assert (
-                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
-            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+            if enable_text_splitting:
+                text = split_sentence(text, language, self.tokenizer.char_limits[language])
+            else:
+                text = [text]
 
-            fake_inputs = self.gpt.compute_embeddings(
-                gpt_cond_latent.to(self.device),
-                text_tokens,
-            )
-            gpt_generator = self.gpt.get_generator(
-                fake_inputs=fake_inputs,
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                do_sample=do_sample,
-                num_beams=1,
-                num_return_sequences=1,
-                length_penalty=float(length_penalty),
-                repetition_penalty=float(repetition_penalty),
-                output_attentions=False,
-                output_hidden_states=True,
-                **hf_generate_kwargs,
-            )
+            for sent in text:
+                sent = sent.strip().lower()
+                text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+                assert (
+                    text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+                ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+
+                fake_inputs = self.gpt.compute_embeddings(
+                    gpt_cond_latent.to(self.device),
+                    text_tokens,
+                )
+                gpt_generator = self.gpt.get_generator(
+                    fake_inputs=fake_inputs,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    do_sample=do_sample,
+                    num_beams=1,
+                    num_return_sequences=1,
+                    length_penalty=float(length_penalty),
+                    repetition_penalty=float(repetition_penalty),
+                    output_attentions=False,
+                    output_hidden_states=True,
+                    **hf_generate_kwargs,
+                )
+
+                last_tokens = []
+                all_latents = []
+                wav_gen_prev = None
+                wav_overlap = None
+                is_end = False
+
+                while not is_end:
+                    try:
+                        x, latent = next(gpt_generator)
+                        last_tokens += [x]
+                        all_latents += [latent]
+                    except StopIteration:
+                        is_end = True
+
+                    if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
+                        gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                        if length_scale != 1.0:
+                            gpt_latents = F.interpolate(
+                                gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                            ).transpose(1, 2)
+                        wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                        wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+                            wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+                        )
+                        last_tokens = []
+                        yield wav_chunk
+
+            if not text_streaming:
+                return
+
+    def inference_stream_text(
+        self,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        # Streaming
+        stream_chunk_size=20,
+        overlap_wav_len=1024,
+        # GPT inference
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        speed=1.0,
+        **hf_generate_kwargs,
+    ):
+        if self._stream_generator is not None:
+            raise Exception('Inference text-streaming already in progress. '
+                            'Did you forget to call inference_finalize_text?')
+
+        # Arguments `text` and `enable_text_splitting` given through holder
+        self._stream_text_holder = [None, None]
+        self._stream_generator = self.inference_stream(
+            None,
+            language,
+            gpt_cond_latent,
+            speaker_embedding,
+            stream_chunk_size=stream_chunk_size,
+            overlap_wav_len=overlap_wav_len,
+            temperature=temperature,
+            length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=do_sample,
+            speed=speed,
+            **hf_generate_kwargs,
+        )
 
-            last_tokens = []
-            all_latents = []
-            wav_gen_prev = None
-            wav_overlap = None
-            is_end = False
-
-            while not is_end:
-                try:
-                    x, latent = next(gpt_generator)
-                    last_tokens += [x]
-                    all_latents += [latent]
-                except StopIteration:
-                    is_end = True
-
-                if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
-                    gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                    if length_scale != 1.0:
-                        gpt_latents = F.interpolate(
-                            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-                        ).transpose(1, 2)
-                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
-                    wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
-                        wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
-                    )
-                    last_tokens = []
-                    yield wav_chunk
+        # Start the generator and return it
+        _ = next(self._stream_generator)
+        return self._stream_generator
+
+    def inference_add_text(self, text: str, enable_text_splitting=False):
+        if self._stream_generator is None:
+            raise Exception('Inference text-streaming not started. '
+                            'Please call inference_stream_text first')
+        self._stream_text_holder[0] = text
+        self._stream_text_holder[1] = enable_text_splitting
+
+    def inference_finalize_text(self):
+        if self._stream_generator is None:
+            raise Exception('Inference text-streaming was not started '
+                            '(start with inference_stream_text)')
+        # Finalize and reset the generator
+        self._stream_text_holder.clear()
+        try:
+            _ = next(self._stream_generator)
+        except StopIteration:
+            pass
+        self._stream_generator = None
 
     def forward(self):
         raise NotImplementedError(
diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md
index b979d04f6e..da75e4ae48 100644
--- a/docs/source/models/xtts.md
+++ b/docs/source/models/xtts.md
@@ -220,7 +220,7 @@ torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
 ```
 
 
-##### Streaming manually
+##### Streaming inference
 
 Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
 Streaming inference is typically slower than regular inference, but it allows to get a first chunk of audio faster.
@@ -253,16 +253,50 @@ chunks = model.inference_stream(
     speaker_embedding
 )
 
-wav_chuncks = []
+wav_chunks = []
 for i, chunk in enumerate(chunks):
     if i == 0:
         print(f"Time to first chunck: {time.time() - t0}")
     print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
-    wav_chuncks.append(chunk)
-wav = torch.cat(wav_chuncks, dim=0)
+    wav_chunks.append(chunk)
+wav = torch.cat(wav_chunks, dim=0)
 torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
 ```
 
+If you also need to stream the input text, you can use `inference_stream_text`, like so:
+
+```python
+# ...same setup as before
+
+def text_streaming_generator():
+    yield "It took me quite a long time to develop a voice and now that I have it I am not going to be silent."
+    yield "Having discovered not just one, but many voices, I will champion each."
+
+print("Inference with text streaming...")
+
+text_gen = text_streaming_generator()
+inf_gen = model.inference_stream_text(
+    "en",
+    gpt_cond_latent,
+    speaker_embedding
+)
+
+wav_chunks = []
+for text in text_gen:
+    # Add text progressively
+    model.inference_add_text(text, enable_text_splitting=True)
+    for chunk in inf_gen:
+        if chunk is None:
+            break  # all chunks generated for the current text
+        print(f"Received chunk {len(wav_chunks)} of audio length {chunk.shape[-1]}")
+        wav_chunks.append(chunk)
+
+# Call finalize to discard the inference generator
+model.inference_finalize_text()
+
+wav = torch.cat(wav_chunks, dim=0)
+torchaudio.save("xtts_streaming_text.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
+```
+
 ### Training
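
The new API boils down to one long-lived generator driven by a sentinel: calling `inference_stream(None, ...)` puts it in text-streaming mode, where it yields `None` whenever it is ready for more input, reads the next `(text, enable_text_splitting)` pair from `self._stream_text_holder`, and returns once `inference_finalize_text` clears the holder. Below is a minimal, self-contained sketch of the same holder-plus-sentinel pattern, detached from XTTS; all names in it are illustrative and not part of this change.

```python
# Standalone sketch of the coroutine-style protocol used by inference_stream_text /
# inference_add_text / inference_finalize_text. Illustrative only: words stand in
# for audio chunks, and the holder carries just the text (no splitting flag).

class TextStreamer:
    def __init__(self):
        self._holder = []       # next text pushed by the caller
        self._generator = None  # single long-lived generator

    def _run(self):
        while True:
            yield None          # sentinel: done with the current text, waiting for more
            if not self._holder:
                return          # holder cleared -> finalize was requested
            text = self._holder[0]
            for word in text.split():
                yield word      # stand-in for yielding a wav chunk

    def start(self):
        self._generator = self._run()
        next(self._generator)   # advance to the first sentinel, like inference_stream_text
        return self._generator

    def add_text(self, text):
        self._holder[:] = [text]

    def finalize(self):
        self._holder.clear()
        next(self._generator, None)  # let the generator hit its `return`
        self._generator = None


streamer = TextStreamer()
gen = streamer.start()
for sentence in ["first piece of text", "and a second one"]:
    streamer.add_text(sentence)
    for chunk in gen:
        if chunk is None:
            break               # this sentence is finished, push more text
        print(chunk)
streamer.finalize()
```

Keeping one generator alive across calls is what lets the caller interleave pushing text and pulling audio chunks without restarting inference between fragments.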