Bug fixes and support for multiple speaker references in XTTS inference #3149

Merged: 13 commits, Nov 7, 2023
30 changes: 1 addition & 29 deletions TTS/tts/layers/xtts/trainer/dataset.py
@@ -2,13 +2,10 @@
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)

@@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
return rel_clip, rel_clip.shape[-1], cond_idxs


def load_audio(audiopath, sampling_rate):
# loading settings chosen following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# use the torchaudio sox_io backend to load mp3 files
audio, lsr = torchaudio_sox_load(audiopath)
else:
# use the torchaudio soundfile backend to load all other file types
audio, lsr = torchaudio_soundfile_load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
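With the duplicate removed, the trainer dataset and inference share the single `load_audio` defined in `TTS/tts/models/xtts.py`. A minimal usage sketch (the clip path is a placeholder):

```python
from TTS.tts.models.xtts import load_audio

# loads mp3 via the sox_io backend and other formats via soundfile,
# downmixes stereo to mono, resamples, and clips values to [-1, 1]
wav = load_audio("/path/to/clip.wav", sampling_rate=22050)
print(wav.shape)  # [1, T] mono tensor at 22050 Hz
```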
3 changes: 1 addition & 2 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -237,8 +237,7 @@ def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
self.config,
s_info["speaker_wav"],
s_info["language"],
gpt_cond_len=3,
decoder="ne_hifigan",
gpt_cond_len=3
)["wav"]
test_audios["{}-audio".format(idx)] = wav

77 changes: 59 additions & 18 deletions TTS/tts/models/xtts.py
@@ -67,6 +67,31 @@ def wav_to_mel_cloning(
return mel


def load_audio(audiopath, sampling_rate):
# loading settings chosen following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# use the torchaudio sox_io backend to load mp3 files
audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
else:
# use the torchaudio soundfile backend to load all other file types
audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
@@ -336,7 +361,7 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.

Args:
audio_path (str): Path to the audio file.
audio (tensor): audio tensor.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3.
"""
@@ -404,20 +429,41 @@ def get_conditioning_latents(
max_ref_length=10,
librosa_trim_db=None,
sound_norm_refs=False,
load_sr=24000,
):
# deal with multiple references
if not isinstance(audio_path, list):
audio_paths = [audio_path]
else:
audio_paths = audio_path

speaker_embeddings = []
audios = []
speaker_embedding = None
for file_path in audio_paths:
# load the audio at 24 kHz to avoid issues with references at different sample rates
audio = load_audio(file_path, load_sr)
audio = audio[:, : load_sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, load_sr)
speaker_embeddings.append(speaker_embedding)

audios.append(audio)

# concatenate all references for the gpt cond latents
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]

if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
speaker_embedding = speaker_embedding.mean(dim=0)

audio, sr = torchaudio.load(audio_path)
audio = audio[:, : sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, sr)
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T]
return gpt_cond_latents, speaker_embedding

def synthesize(self, text, config, speaker_wav, language, **kwargs):
@@ -426,7 +472,7 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
Args:
text (str): Input text.
config (XttsConfig): Config with inference parameters.
speaker_wav (str): Path to the speaker audio file for cloning.
speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
language (str): Language ID of the speaker.
**kwargs: Inference settings. See `inference()`.

@@ -436,11 +482,6 @@
as latents used at inference.

"""

# Make the synthesizer happy 🥳
if isinstance(speaker_wav, list):
speaker_wav = speaker_wav[0]

return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)

def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
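The key change above is that `get_conditioning_latents` now computes one speaker embedding per reference and averages them, while concatenating all reference waveforms before computing the GPT conditioning latents. A standalone sketch of that combination pattern (`embed_fn` and `latent_fn` are hypothetical stand-ins for `get_speaker_embedding` and `get_gpt_cond_latents`):

```python
import torch

def combine_references(audios, embed_fn, latent_fn):
    # one speaker embedding per reference clip, then the mean over clips
    speaker_embedding = torch.stack([embed_fn(a) for a in audios]).mean(dim=0)
    # all clips concatenated along time for the GPT conditioning latents
    full_audio = torch.cat(audios, dim=-1)
    return latent_fn(full_audio), speaker_embedding

# toy usage with dummy embed/latent functions
refs = [torch.randn(1, 24000 * 3), torch.randn(1, 24000 * 5)]
gpt_latents, spk_emb = combine_references(
    refs,
    embed_fn=lambda a: a.mean(dim=-1, keepdim=True),
    latent_fn=lambda a: a[:, :1024],
)
print(gpt_latents.shape, spk_emb.shape)  # torch.Size([1, 1024]) torch.Size([1, 1])
```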
41 changes: 37 additions & 4 deletions docs/source/models/xtts.md
@@ -39,19 +39,33 @@ You can also mail us at [email protected].
### Inference
#### 🐸TTS API

##### Single reference
```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)

# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav="/path/to/target/speaker.wav",
speaker_wav=["/path/to/target/speaker.wav"],
language="en")
```

##### Multiple references
```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)

# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav", "/path/to/target/speaker_3.wav"],
language="en")
```
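Each reference wav contributes its own speaker embedding (the embeddings are averaged across files), and all references are concatenated to build the GPT conditioning latents, so adding more clips of the same speaker generally yields a more robust clone.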

#### 🐸TTS Command line

##### Single reference
```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
--text "Bugün okula gitmek istemiyorum." \
@@ -60,6 +74,25 @@
--use_cuda true
```

##### Multiple references
```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
--text "Bugün okula gitmek istemiyorum." \
--speaker_wav /path/to/target/speaker.wav /path/to/target/speaker_2.wav /path/to/target/speaker_3.wav \
--language_idx tr \
--use_cuda true
```
or, to use all wav files in a directory:

```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
--text "Bugün okula gitmek istemiyorum." \
--speaker_wav /path/to/target/*.wav \
--language_idx tr \
--use_cuda true
```
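Note that the glob is expanded by the shell, so `--speaker_wav` receives each matching wav file as a separate path, just like the explicit list above.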


#### model directly

If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
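A plain pip install is usually enough (treat any version pin as an assumption that depends on your torch/CUDA setup):

```console
pip install deepspeed
```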
@@ -83,7 +116,7 @@
model.cuda()

print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

print("Inference...")
out = model.inference(
@@ -120,7 +153,7 @@
model.cuda()

print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

print("Inference...")
t0 = time.time()
@@ -177,7 +210,7 @@
model.cuda()

print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])

print("Inference...")
out = model.inference(
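For completeness, a minimal end-to-end sketch of multi-reference inference with the two-value `get_conditioning_latents` from this PR (the `temperature` value and the output handling are illustrative assumptions, not part of this diff):

```python
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/")
model.cuda()

# several clips of the same speaker improve the clone
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["reference.wav", "reference_2.wav"]
)

out = model.inference(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    temperature=0.7,  # assumed value for illustration
)
torchaudio.save("xtts_output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```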
4 changes: 2 additions & 2 deletions recipes/ljspeech/xtts_v1/train_gpt_xtts.py
@@ -71,9 +71,9 @@


# Training sentences generation
SPEAKER_REFERENCE = (
SPEAKER_REFERENCE = [
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
)
]
LANGUAGE = config_dataset.language


13 changes: 6 additions & 7 deletions recipes/ljspeech/xtts_v2/train_gpt_xtts.py
@@ -53,15 +53,14 @@
print(" > Downloading DVAE files!")
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)

# ToDo: Update links for XTTS v2.0

# Download XTTS v2.0 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/model.pth"
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"

# XTTS transfer learning parameters: you need to provide the paths of the XTTS model checkpoint that you want to fine-tune.
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file

# download XTTS v2.0 files if needed
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
@@ -72,9 +71,9 @@


# Training sentences generation
SPEAKER_REFERENCE = (
SPEAKER_REFERENCE = [
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
)
]
LANGUAGE = config_dataset.language


2 changes: 1 addition & 1 deletion tests/xtts_tests/test_xtts_gpt_train.py
@@ -60,7 +60,7 @@


# Training sentences generation
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language


4 changes: 3 additions & 1 deletion tests/xtts_tests/test_xtts_v2-0_gpt_train.py
@@ -58,7 +58,7 @@


# Training sentences generation
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language


@@ -87,7 +87,9 @@
gpt_use_masking_gt_prompt_approach=True,
gpt_use_perceiver_resampler=True,
)

audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)

config = GPTTrainerConfig(
epochs=1,
output_path=OUT_PATH,
11 changes: 7 additions & 4 deletions tests/zoo_tests/test_models.py
@@ -101,7 +101,9 @@ def test_xtts_streaming():
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
speaker_wav.append(speaker_wav_2)
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
@@ -131,20 +133,21 @@ def test_xtts_v2():
"""XTTS is too big to run on github actions. We need to test it locally"""
output_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
use_gpu = torch.cuda.is_available()
if use_gpu:
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
)
else:
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
)


@@ -153,7 +156,7 @@ def test_xtts_v2_streaming():
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))