diff --git a/TTS/.models.json b/TTS/.models.json index b33e4fd323..5b35d4e267 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -10,34 +10,22 @@ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5" ], + "model_hash": "ae9e4b39e095fd5728fe7f7931eccoqui", "default_vocoder": null, "commit": "480a6cdf7", "license": "CPML", "contact": "info@coqui.ai", "tos_required": true }, - "xtts_v1": { - "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.", - "hf_url": [ - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json" - ], - "default_vocoder": null, - "commit": "e5140314", - "license": "CPML", - "contact": "info@coqui.ai", - "tos_required": true - }, "xtts_v1.1": { "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.", "hf_url": [ - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5" + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5" ], - "model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad", + "model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59", "default_vocoder": null, "commit": "82910a63", "license": "CPML", diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 5d8b2ae66b..8cb90ad0f8 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -2,13 +2,10 @@ import random import sys -import numpy as np import torch import torch.nn.functional as F import torch.utils.data -import torchaudio -from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load -from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load +from TTS.tts.models.xtts import load_audio torch.set_num_threads(1) @@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, return rel_clip, rel_clip.shape[-1], cond_idxs -def load_audio(audiopath, sampling_rate): - # better load setting following: https://github.com/faroit/python_audio_loading_benchmark - if audiopath[-4:] == ".mp3": - # it uses torchaudio with sox backend to load mp3 - audio, lsr = torchaudio_sox_load(audiopath) - else: - # it uses torchaudio soundfile backend to load all the others data type - audio, lsr = torchaudio_soundfile_load(audiopath) - - # stereo to mono if needed - if audio.size(0) != 1: - audio = torch.mean(audio, dim=0, keepdim=True) - - if lsr != sampling_rate: - audio = torchaudio.functional.resample(audio, lsr, sampling_rate) - - # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. - # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. - if torch.any(audio > 10) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") - # clip audio invalid values - audio.clip_(-1, 1) - return audio - - class XTTSDataset(torch.utils.data.Dataset): def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False): self.config = config diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index ef32a4abd0..005b30bede 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -238,7 +238,6 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 s_info["speaker_wav"], s_info["language"], gpt_cond_len=3, - decoder="ne_hifigan", )["wav"] test_audios["{}-audio".format(idx)] = wav diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index af94675be9..4ab0027072 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -67,6 +67,31 @@ def wav_to_mel_cloning( return mel +def load_audio(audiopath, sampling_rate): + # better load setting following: https://github.com/faroit/python_audio_loading_benchmark + if audiopath[-4:] == ".mp3": + # it uses torchaudio with sox backend to load mp3 + audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath) + else: + # it uses torchaudio soundfile backend to load all the others data type + audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath) + + # stereo to mono if needed + if audio.size(0) != 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 10) or not torch.any(audio < 0): + print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + # clip audio invalid values + audio.clip_(-1, 1) + return audio + + def pad_or_truncate(t, length): """ Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. @@ -86,78 +111,6 @@ def pad_or_truncate(t, length): return tp -def load_discrete_vocoder_diffuser( - trained_diffusion_steps=4000, - desired_diffusion_steps=200, - cond_free=True, - cond_free_k=1, - sampler="ddim", -): - """ - Load a GaussianDiffusion instance configured for use as a decoder. - - Args: - trained_diffusion_steps (int): The number of diffusion steps used during training. - desired_diffusion_steps (int): The number of diffusion steps to use during inference. - cond_free (bool): Whether to use a conditioning-free model. - cond_free_k (int): The number of samples to use for conditioning-free models. - sampler (str): The name of the sampler to use. - - Returns: - A SpacedDiffusion instance configured with the given parameters. - """ - return SpacedDiffusion( - use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), - model_mean_type="epsilon", - model_var_type="learned_range", - loss_type="mse", - betas=get_named_beta_schedule("linear", trained_diffusion_steps), - conditioning_free=cond_free, - conditioning_free_k=cond_free_k, - sampler=sampler, - ) - - -def do_spectrogram_diffusion( - diffusion_model, - diffuser, - latents, - conditioning_latents, - temperature=1, -): - """ - Generate a mel-spectrogram using a diffusion model and a diffuser. - - Args: - diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. - diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise. - latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes. - conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes. - temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1. - - Returns: - torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram. - """ - with torch.no_grad(): - output_seq_len = ( - latents.shape[1] * 4 * 24000 // 22050 - ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. - output_shape = (latents.shape[0], 100, output_seq_len) - precomputed_embeddings = diffusion_model.timestep_independent( - latents, conditioning_latents, output_seq_len, False - ) - - noise = torch.randn(output_shape, device=latents.device) * temperature - mel = diffuser.sample_loop( - diffusion_model, - output_shape, - noise=noise, - model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, - progress=False, - ) - return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] - - @dataclass class XttsAudioConfig(Coqpit): """ @@ -336,7 +289,7 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 3): """Compute the conditioning latents for the GPT model from the given audio. Args: - audio_path (str): Path to the audio file. + audio (tensor): audio tensor. sr (int): Sample rate of the audio. length (int): Length of the audio in seconds. Defaults to 3. """ @@ -404,20 +357,41 @@ def get_conditioning_latents( max_ref_length=10, librosa_trim_db=None, sound_norm_refs=False, + load_sr=24000, ): + # deal with multiples references + if not isinstance(audio_path, list): + audio_paths = [audio_path] + else: + audio_paths = audio_path + + speaker_embeddings = [] + audios = [] speaker_embedding = None + for file_path in audio_paths: + # load the audio in 24khz to avoid issued with multiple sr references + audio = load_audio(file_path, load_sr) + audio = audio[:, : load_sr * max_ref_length].to(self.device) + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) + if sound_norm_refs: + audio = (audio / torch.abs(audio).max()) * 0.75 + if librosa_trim_db is not None: + audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + + speaker_embedding = self.get_speaker_embedding(audio, load_sr) + speaker_embeddings.append(speaker_embedding) + + audios.append(audio) + + # use a merge of all references for gpt cond latents + full_audio = torch.cat(audios, dim=-1) + gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T] + + if speaker_embeddings: + speaker_embedding = torch.stack(speaker_embeddings) + speaker_embedding = speaker_embedding.mean(dim=0) - audio, sr = torchaudio.load(audio_path) - audio = audio[:, : sr * max_ref_length].to(self.device) - if audio.shape[0] > 1: - audio = audio.mean(0, keepdim=True) - if sound_norm_refs: - audio = (audio / torch.abs(audio).max()) * 0.75 - if librosa_trim_db is not None: - audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] - - speaker_embedding = self.get_speaker_embedding(audio, sr) - gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T] return gpt_cond_latents, speaker_embedding def synthesize(self, text, config, speaker_wav, language, **kwargs): @@ -426,7 +400,7 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs): Args: text (str): Input text. config (XttsConfig): Config with inference parameters. - speaker_wav (str): Path to the speaker audio file for cloning. + speaker_wav (list): List of paths to the speaker audio files to be used for cloning. language (str): Language ID of the speaker. **kwargs: Inference settings. See `inference()`. @@ -436,11 +410,6 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs): as latents used at inference. """ - - # Make the synthesizer happy 🥳 - if isinstance(speaker_wav, list): - speaker_wav = speaker_wav[0] - return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs) def inference_with_config(self, text, config, ref_audio_path, language, **kwargs): @@ -522,27 +491,6 @@ def full_inference( gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used else the first `gpt_cond_len` secs is used. Defaults to 6 seconds. - decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has - more chances to iteratively refine the output, which should theoretically mean a higher quality output. - Generally a value above 250 is not noticeably better, however. Defaults to 100. - - cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion - performs two forward passes for each diffusion step: one with the outputs of the autoregressive model - and one with no conditioning priors. The output of the two is blended according to the cond_free_k - value below. Conditioning-free diffusion is the real deal, and dramatically improves realism. - Defaults to True. - - cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the - conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the - conditioning-free signal. Defaults to 2.0. - - diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. - Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared. - Defaults to 1.0. - - decoder: (str) Selects the decoder to use between ("hifigan", "diffusion") - Defaults to hifigan - hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation here: https://huggingface.co/docs/transformers/internal/generation_utils @@ -569,12 +517,6 @@ def full_inference( top_k=top_k, top_p=top_p, do_sample=do_sample, - decoder_iterations=decoder_iterations, - cond_free=cond_free, - cond_free_k=cond_free_k, - diffusion_temperature=diffusion_temperature, - decoder_sampler=decoder_sampler, - decoder=decoder, **hf_generate_kwargs, ) @@ -592,13 +534,6 @@ def inference( top_k=50, top_p=0.85, do_sample=True, - # Decoder inference - decoder_iterations=100, - cond_free=True, - cond_free_k=2, - diffusion_temperature=1.0, - decoder_sampler="ddim", - decoder="hifigan", num_beams=1, **hf_generate_kwargs, ): @@ -693,8 +628,6 @@ def inference_stream( top_k=50, top_p=0.85, do_sample=True, - # Decoder inference - decoder="hifigan", **hf_generate_kwargs, ): text = text.strip().lower() diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index 1d034aeadf..8167a1d1a9 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -39,6 +39,7 @@ You can also mail us at info@coqui.ai. ### Inference #### 🐸TTS API +##### Single reference ```python from TTS.api import TTS tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) @@ -46,12 +47,25 @@ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) # generate speech by cloning a voice using default settings tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", file_path="output.wav", - speaker_wav="/path/to/target/speaker.wav", + speaker_wav=["/path/to/target/speaker.wav"], + language="en") +``` + +##### Multiple references +```python +from TTS.api import TTS +tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) + +# generate speech by cloning a voice using default settings +tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + file_path="output.wav", + speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav", "/path/to/target/speaker_3.wav"], language="en") ``` #### 🐸TTS Command line +##### Single reference ```console tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \ --text "Bugün okula gitmek istemiyorum." \ @@ -60,6 +74,25 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t --use_cuda true ``` +##### Multiple references +```console + tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \ + --text "Bugün okula gitmek istemiyorum." \ + --speaker_wav /path/to/target/speaker.wav /path/to/target/speaker_2.wav /path/to/target/speaker_3.wav \ + --language_idx tr \ + --use_cuda true +``` +or for all wav files in a directory you can use: + +```console + tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \ + --text "Bugün okula gitmek istemiyorum." \ + --speaker_wav /path/to/target/*.wav \ + --language_idx tr \ + --use_cuda true +``` + + #### model directly If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first. @@ -83,7 +116,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru model.cuda() print("Computing speaker latents...") -gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav") +gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"]) print("Inference...") out = model.inference( @@ -120,7 +153,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru model.cuda() print("Computing speaker latents...") -gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav") +gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"]) print("Inference...") t0 = time.time() @@ -177,7 +210,7 @@ model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENI model.cuda() print("Computing speaker latents...") -gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE) +gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE]) print("Inference...") out = model.inference( diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index 268a033535..7d8f4064c5 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -41,8 +41,8 @@ # DVAE files -DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth" -MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth" +DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/dvae.pth" +MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/mel_stats.pth" # Set the path to the downloaded files DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1]) @@ -55,8 +55,8 @@ # Download XTTS v1.1 checkpoint if needed -TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json" -XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth" +TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json" +XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth" # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file @@ -71,9 +71,9 @@ # Training sentences generations -SPEAKER_REFERENCE = ( +SPEAKER_REFERENCE = [ "./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences -) +] LANGUAGE = config_dataset.language diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index d94204ca4d..fa42174982 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -40,6 +40,7 @@ os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) +# ToDo: update DVAE checkpoint # DVAE files DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth" MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth" @@ -53,15 +54,14 @@ print(" > Downloading DVAE files!") ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) -# ToDo: Update links for XTTS v2.0 # Download XTTS v2.0 checkpoint if needed -TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/vocab.json" -XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/model.pth" +TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" +XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth" # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. -TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file -XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file +TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file +XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file # download XTTS v2.0 files if needed if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): @@ -72,9 +72,9 @@ # Training sentences generations -SPEAKER_REFERENCE = ( +SPEAKER_REFERENCE = [ "./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences -) +] LANGUAGE = config_dataset.language @@ -90,9 +90,9 @@ def main(): dvae_checkpoint=DVAE_CHECKPOINT, xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune tokenizer_file=TOKENIZER_FILE, - gpt_num_audio_tokens=8194, - gpt_start_audio_token=8192, - gpt_stop_audio_token=8193, + gpt_num_audio_tokens=1024, + gpt_start_audio_token=1025, + gpt_stop_audio_token=1026, gpt_use_masking_gt_prompt_approach=True, gpt_use_perceiver_resampler=True, ) diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 03514daa3b..12c547d684 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -60,7 +60,7 @@ # Training sentences generations -SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences +SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences LANGUAGE = config_dataset.language diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 8099503855..b19b7210d8 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -58,7 +58,7 @@ # Training sentences generations -SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences +SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences LANGUAGE = config_dataset.language @@ -87,7 +87,9 @@ gpt_use_masking_gt_prompt_approach=True, gpt_use_perceiver_resampler=True, ) + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) + config = GPTTrainerConfig( epochs=1, output_path=OUT_PATH, diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index 2f9399add8..79aef5cb14 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -101,7 +101,9 @@ def test_xtts_streaming(): from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts - speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")] + speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav") + speaker_wav.append(speaker_wav_2) model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") config = XttsConfig() config.load_json(os.path.join(model_path, "config.json")) @@ -131,20 +133,21 @@ def test_xtts_v2(): """XTTS is too big to run on github actions. We need to test it locally""" output_path = os.path.join(get_tests_output_path(), "output.wav") speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav") use_gpu = torch.cuda.is_available() if use_gpu: run_cli( "yes | " f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 " f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True ' - f'--speaker_wav "{speaker_wav}" --language_idx "en"' + f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"' ) else: run_cli( "yes | " f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 " f'--text "This is an example." --out_path "{output_path}" --progress_bar False ' - f'--speaker_wav "{speaker_wav}" --language_idx "en"' + f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"' ) @@ -153,7 +156,7 @@ def test_xtts_v2_streaming(): from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts - speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")] model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2") config = XttsConfig() config.load_json(os.path.join(model_path, "config.json"))