XTTS v2.0 (#3137)
* Implement most similar ref training approach

* Use non-enhanced hifigan for test samples

* Add Perceiver

* Update GPT Trainer for perceiver support

* Update XTTS docs

* Bug fix masking with XTTS perceiver

* Bug fix on gpt forward

* Bug Fix on XTTS v2.0 training

* Add XTTS v2.0 unit tests

* Add XTTS v2.0 inference unit tests

* Bug Fix on diffusion inference

* Add XTTS v2.0 training recipe

* Placeholder model entry

* Add cloning params to config

* Make prompt embedding configurable

* Make cloning configurable

* Cheap fix for a cheaper fix

* Prevent resampling

* Update model entry

* Update docs

* Update requirements

* Code linting

* Add xtts v2 to sep tests

* Bug fix on XTTS get_gpt_cond_latents

* Bug fix on rebase

* Make style

* Bug fix in Japanese tokenizer

* Add num2words to deps

* Remove unused kwarg and add num_beams=1 as default

---------

Co-authored-by: Eren Gölge <[email protected]>
Edresson and erogol committed Nov 6, 2023
1 parent 38f6f8f commit e45227d
Showing 20 changed files with 1,782 additions and 661 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -169,3 +169,4 @@ wandb
depot/*
coqui_recipes/*
local_scripts/*
coqui_demos/*
14 changes: 14 additions & 0 deletions TTS/.models.json
@@ -2,6 +2,20 @@
"tts_models": {
"multilingual": {
"multi-dataset": {
"xtts_v2": {
"description": "XTTS-v2 by Coqui with 16 languages.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
],
"default_vocoder": null,
"commit": "480a6cdf7",
"license": "CPML",
"contact": "[email protected]",
"tos_required": true
},
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
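With the xtts_v2 entry above, the new model can be pulled by its registered name through the Python API. Below is a minimal usage sketch, not part of this commit; it assumes the standard TTS.api surface and an illustrative local reference clip speaker.wav:

from TTS.api import TTS

# First use downloads the files listed under hf_url (the CPML terms must be accepted).
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

# Clone the voice in the reference clip and synthesize English speech.
tts.tts_to_file(
    text="XTTS v2 now covers sixteen languages.",
    speaker_wav="speaker.wav",   # illustrative path to a short reference recording
    language="en",
    file_path="output.wav",
)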
34 changes: 33 additions & 1 deletion TTS/tts/configs/xtts_config.py
@@ -59,6 +59,16 @@ class XttsConfig(BaseTTSConfig):
decoder_sampler (str):
Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
gpt_cond_len (int):
Seconds of audio to be used as conditioning for the autoregressive model. Defaults to `3`.
max_ref_len (int):
Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
sound_norm_refs (bool):
Whether to normalize the conditioning audio. Defaults to `False`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
@@ -74,7 +84,24 @@
audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
model_dir: str = None
languages: List[str] = field(
default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"]
default_factory=lambda: [
"en",
"es",
"fr",
"de",
"it",
"pt",
"pl",
"tr",
"ru",
"nl",
"cs",
"ar",
"zh-cn",
"hu",
"ko",
"ja",
]
)

# inference params
@@ -88,3 +115,8 @@ class XttsConfig(BaseTTSConfig):
num_gpt_outputs: int = 1
decoder_iterations: int = 30
decoder_sampler: str = "ddim"

# cloning
gpt_cond_len: int = 3
max_ref_len: int = 10
sound_norm_refs: bool = False
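The three new cloning fields map directly onto inference-time voice cloning: gpt_cond_len bounds the reference audio fed to the GPT conditioning encoder, max_ref_len bounds the reference audio used for the decoder, and sound_norm_refs toggles normalization of the conditioning audio. A small sketch of overriding them (values are illustrative, not recommendations):

from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig(
    gpt_cond_len=6,        # seconds of reference audio for GPT conditioning
    max_ref_len=10,        # cap, in seconds, on reference audio for the decoder
    sound_norm_refs=True,  # normalize reference audio before computing latents
)
print(config.languages)  # the v2 list now also contains "hu", "ko" and "ja"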
21 changes: 5 additions & 16 deletions TTS/tts/layers/tortoise/dpm_solver.py
@@ -562,21 +562,15 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [
3,
] * (
orders = [3,] * (
K - 2
) + [2, 1]
elif steps % 3 == 1:
orders = [
3,
] * (
orders = [3,] * (
K - 1
) + [1]
else:
orders = [
3,
] * (
orders = [3,] * (
K - 1
) + [2]
elif order == 2:
@@ -587,9 +581,7 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
] * K
else:
K = steps // 2 + 1
orders = [
2,
] * (
orders = [2,] * (
K - 1
) + [1]
elif order == 1:
@@ -1448,10 +1440,7 @@ def sample(
model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps,
order=order,
skip_type=skip_type,
65 changes: 52 additions & 13 deletions TTS/tts/layers/xtts/gpt.py
@@ -11,6 +11,7 @@

from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler


def null_position_embeddings(range, dim):
@@ -105,6 +106,8 @@ def __init__(
checkpointing=False,
average_conditioning_embeddings=False,
label_smoothing=0.0,
use_perceiver_resampler=False,
perceiver_cond_length_compression=256,
):
"""
Args:
@@ -132,13 +135,12 @@ def __init__(
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.conditioning_dropout = nn.Dropout1d(0.1)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.use_perceiver_resampler = use_perceiver_resampler
self.perceiver_cond_length_compression = perceiver_cond_length_compression

self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)

self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)

(
self.gpt,
self.mel_pos_embedding,
@@ -165,9 +167,29 @@ def __init__(
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)

if self.use_perceiver_resampler:
# XTTS v2
self.conditioning_perceiver = PerceiverResampler(
dim=model_dim,
depth=2,
dim_context=model_dim,
num_latents=32,
dim_head=64,
heads=8,
ff_mult=4,
use_flash_attn=False,
)
else:
# XTTS v1
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)

def get_grad_norm_parameter_groups(self):
return {
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
"conditioning_perceiver": list(self.conditioning_perceiver.parameters())
if self.use_perceiver_resampler
else None,
"gpt": list(self.gpt.parameters()),
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
@@ -250,11 +272,8 @@ def get_logits(
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
if attn_mask_cond is not None:
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
else:
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

gpt_out = self.gpt(
inputs_embeds=emb,
@@ -318,7 +337,6 @@ def get_prompts(self, prompt_codes):
prompt_len = 3
prompt_len = prompt_len * 24 # in frames
if prompt_codes.shape[-1] >= prompt_len:
new_prompt = []
for i in range(prompt_codes.shape[0]):
if lengths[i] < prompt_len:
start = 0
@@ -340,7 +358,9 @@ def get_style_emb(self, cond_input, return_latent=False):
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
conds = self.conditioning_encoder(cond_input)
conds = self.conditioning_encoder(cond_input) # (b, d, s)
if self.use_perceiver_resampler:
conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32)
else:
# already computed
conds = cond_input.unsqueeze(1)
@@ -354,6 +374,7 @@
wav_lengths,
cond_mels=None,
cond_idxs=None,
cond_lens=None,
cond_latents=None,
return_attentions=False,
return_latent=False,
@@ -379,10 +400,24 @@
max_text_len = text_lengths.max()
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

if cond_lens is not None:
if self.use_perceiver_resampler:
cond_lens = cond_lens // self.perceiver_cond_length_compression
else:
cond_lens = cond_lens // self.code_stride_len

if cond_idxs is not None:
# recompute cond idxs for mel lengths
for idx, l in enumerate(code_lengths):
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
for idx in range(cond_idxs.size(0)):
if self.use_perceiver_resampler:
cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression
else:
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len

# ensure that the cond_mel does not have padding
# if cond_lens is not None and cond_idxs is None:
# min_cond_len = torch.min(cond_lens)
# cond_mels = cond_mels[:, :, :, :min_cond_len]

# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()
@@ -450,9 +485,13 @@
)

if cond_idxs is not None:
# use masking approach
for idx, r in enumerate(cond_idxs):
l = r[1] - r[0]
attn_mask_cond[idx, l:] = 0.0
elif cond_lens is not None:
for idx, l in enumerate(cond_lens):
attn_mask_cond[idx, l:] = 0.0

for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
Expand Down Expand Up @@ -523,7 +562,7 @@ def forward(

def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
self.compute_embeddings(cond_latents, text_inputs)
return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs)
return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)

def compute_embeddings(
self,
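The central change in gpt.py is how the speaker prompt is built: XTTS v1 embedded discrete prompt codes with prompt_embedding, while XTTS v2 routes the ConditioningEncoder output through a PerceiverResampler that cross-attends it into 32 learned latents, so a reference of any length yields a fixed-size conditioning prompt. A rough shape-flow sketch mirroring get_style_emb; model_dim and the 250-frame input are illustrative, and the forward signature is assumed from its use above:

import torch

from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

model_dim = 1024  # illustrative; the real value comes from the GPT config
perceiver = PerceiverResampler(
    dim=model_dim,
    depth=2,
    dim_context=model_dim,
    num_latents=32,
    dim_head=64,
    heads=8,
    ff_mult=4,
    use_flash_attn=False,
)

conds = torch.randn(1, model_dim, 250)       # (b, d, s): ConditioningEncoder output for ~250 mel frames
latents = perceiver(conds.permute(0, 2, 1))  # (b, s, d) in -> (b, 32, d) out
cond_latents = latents.transpose(1, 2)       # (b, d, 32): fixed-size prompt regardless of s
print(cond_latents.shape)                    # torch.Size([1, 1024, 32])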