diff --git a/TTS/VERSION b/TTS/VERSION index 59dad104b0..16eb94e711 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.21.2 +0.21.3 diff --git a/TTS/demos/xtts_ft_demo/requirements.txt b/TTS/demos/xtts_ft_demo/requirements.txt new file mode 100644 index 0000000000..cb5b16f66e --- /dev/null +++ b/TTS/demos/xtts_ft_demo/requirements.txt @@ -0,0 +1,2 @@ +faster_whisper==0.9.0 +gradio==4.7.1 \ No newline at end of file diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py new file mode 100644 index 0000000000..536faa0108 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -0,0 +1,160 @@ +import os +import gc +import torchaudio +import pandas +from faster_whisper import WhisperModel +from glob import glob + +from tqdm import tqdm + +import torch +import torchaudio +# torch.set_num_threads(1) + +from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners + +torch.set_num_threads(16) + + +import os + +audio_types = (".wav", ".mp3", ".flac") + + +def list_audios(basePath, contains=None): + # return the set of files that are valid + return list_files(basePath, validExts=audio_types, contains=contains) + +def list_files(basePath, validExts=None, contains=None): + # loop over the directory structure + for (rootDir, dirNames, filenames) in os.walk(basePath): + # loop over the filenames in the current directory + for filename in filenames: + # if the contains string is not none and the filename does not contain + # the supplied string, then ignore the file + if contains is not None and filename.find(contains) == -1: + continue + + # determine the file extension of the current file + ext = filename[filename.rfind("."):].lower() + + # check to see if the file is an audio and should be processed + if validExts is None or ext.endswith(validExts): + # construct the path to the audio and yield it + audioPath = os.path.join(rootDir, filename) + yield audioPath + +def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + audio_total_size = 0 + # make sure that ooutput file exists + os.makedirs(out_path, exist_ok=True) + + # Loading Whisper + device = "cuda" if torch.cuda.is_available() else "cpu" + + print("Loading Whisper Model!") + asr_model = WhisperModel("large-v2", device=device, compute_type="float16") + + metadata = {"audio_file": [], "text": [], "speaker_name": []} + + if gradio_progress is not None: + tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...") + else: + tqdm_object = tqdm(audio_files) + + for audio_path in tqdm_object: + wav, sr = torchaudio.load(audio_path) + # stereo to mono if needed + if wav.size(0) != 1: + wav = torch.mean(wav, dim=0, keepdim=True) + + wav = wav.squeeze() + audio_total_size += (wav.size(-1) / sr) + + segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) + segments = list(segments) + i = 0 + sentence = "" + sentence_start = None + first_word = True + # added all segments words in a unique list + words_list = [] + for _, segment in enumerate(segments): + words = list(segment.words) + words_list.extend(words) + + # process each word + for word_idx, word in enumerate(words_list): + if first_word: + sentence_start = word.start + # If it is the first sentence, add buffer or get the begining of the file + if word_idx == 0: + sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start + else: + # get previous sentence end + previous_word_end = words_list[word_idx - 1].end + # add buffer or get the silence midle between the previous sentence and the current one + sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) + + sentence = word.word + first_word = False + else: + sentence += word.word + + if word.word[-1] in ["!", ".", "?"]: + sentence = sentence[1:] + # Expand number and abbreviations plus normalization + sentence = multilingual_cleaners(sentence, target_language) + audio_file_name, _ = os.path.splitext(os.path.basename(audio_path)) + + audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav" + + # Check for the next word's existence + if word_idx + 1 < len(words_list): + next_word_start = words_list[word_idx + 1].start + else: + # If don't have more words it means that it is the last sentence then use the audio len as next word start + next_word_start = (wav.shape[0] - 1) / sr + + # Average the current word end and next word start + word_end = min((word.end + next_word_start) / 2, word.end + buffer) + + absoulte_path = os.path.join(out_path, audio_file) + os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) + i += 1 + first_word = True + + audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) + # if the audio is too short ignore it (i.e < 0.33 seconds) + if audio.size(-1) >= sr/3: + torchaudio.save(absoulte_path, + audio, + sr + ) + else: + continue + + metadata["audio_file"].append(audio_file) + metadata["text"].append(sentence) + metadata["speaker_name"].append(speaker_name) + + df = pandas.DataFrame(metadata) + df = df.sample(frac=1) + num_val_samples = int(len(df)*eval_percentage) + + df_eval = df[:num_val_samples] + df_train = df[num_val_samples:] + + df_train = df_train.sort_values('audio_file') + train_metadata_path = os.path.join(out_path, "metadata_train.csv") + df_train.to_csv(train_metadata_path, sep="|", index=False) + + eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") + df_eval = df_eval.sort_values('audio_file') + df_eval.to_csv(eval_metadata_path, sep="|", index=False) + + # deallocate VRAM and RAM + del asr_model, df_train, df_eval, df, metadata + gc.collect() + + return train_metadata_path, eval_metadata_path, audio_total_size \ No newline at end of file diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py new file mode 100644 index 0000000000..a98765c3e7 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -0,0 +1,172 @@ +import os +import gc + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.utils.manage import ModelManager + + +def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995): + # Logging parameters + RUN_NAME = "GPT_XTTS_FT" + PROJECT_NAME = "XTTS_trainer" + DASHBOARD_LOGGER = "tensorboard" + LOGGER_URI = None + + # Set here the path that the checkpoints will be saved. Default: ./run/training/ + OUT_PATH = os.path.join(output_path, "run", "training") + + # Training Parameters + OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False + START_WITH_EVAL = False # if True it will star with evaluation + BATCH_SIZE = batch_size # set here the batch size + GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps + + + # Define here the dataset that you want to use for the fine-tuning on. + config_dataset = BaseDatasetConfig( + formatter="coqui", + dataset_name="ft_dataset", + path=os.path.dirname(train_csv), + meta_file_train=train_csv, + meta_file_val=eval_csv, + language=language, + ) + + # Add here the configs of the datasets + DATASETS_CONFIG_LIST = [config_dataset] + + # Define the path where XTTS v2.0.1 files will be downloaded + CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") + os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) + + + # DVAE files + DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" + MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" + + # Set the path to the downloaded files + DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK)) + MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK)) + + # download DVAE files if needed + if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): + print(" > Downloading DVAE files!") + ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) + + + # Download XTTS v2.0 checkpoint if needed + TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" + XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth" + XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json" + + # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. + TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file + XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file + XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file + + # download XTTS v2.0 files if needed + if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): + print(" > Downloading XTTS v2.0 files!") + ModelManager._download_model_files( + [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) + + # init args and config + model_args = GPTArgs( + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs + debug_loading_failures=False, + max_wav_length=max_audio_length, # ~11.6 seconds + max_text_length=200, + mel_norm_file=MEL_NORM_FILE, + dvae_checkpoint=DVAE_CHECKPOINT, + xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune + tokenizer_file=TOKENIZER_FILE, + gpt_num_audio_tokens=1026, + gpt_start_audio_token=1024, + gpt_stop_audio_token=1025, + gpt_use_masking_gt_prompt_approach=True, + gpt_use_perceiver_resampler=True, + ) + # define audio config + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) + # training parameters config + config = GPTTrainerConfig( + epochs=num_epochs, + output_path=OUT_PATH, + model_args=model_args, + run_name=RUN_NAME, + project_name=PROJECT_NAME, + run_description=""" + GPT XTTS training + """, + dashboard_logger=DASHBOARD_LOGGER, + logger_uri=LOGGER_URI, + audio=audio_config, + batch_size=BATCH_SIZE, + batch_group_size=48, + eval_batch_size=BATCH_SIZE, + num_loader_workers=8, + eval_split_max_size=256, + print_step=50, + plot_step=100, + log_model_step=100, + save_step=1000, + save_n_checkpoints=1, + save_checkpoints=True, + # target_loss="loss", + print_eval=False, + # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. + optimizer="AdamW", + optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS, + optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate + lr_scheduler="MultiStepLR", + # it was adjusted accordly for the new step scheme + lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, + test_sentences=[], + ) + + # init the model from config + model = GPTTrainer.init_from_config(config) + + # load training samples + train_samples, eval_samples = load_tts_samples( + DATASETS_CONFIG_LIST, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the trainer and 🚀 + trainer = Trainer( + TrainerArgs( + restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter + skip_train_epoch=False, + start_with_eval=START_WITH_EVAL, + grad_accum_steps=GRAD_ACUMM_STEPS, + ), + config, + output_path=OUT_PATH, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + ) + trainer.fit() + + # get the longest text audio file to use as speaker reference + samples_len = [len(item["text"].split(" ")) for item in train_samples] + longest_text_idx = samples_len.index(max(samples_len)) + speaker_ref = train_samples[longest_text_idx]["audio_file"] + + trainer_out_path = trainer.output_path + + # deallocate VRAM and RAM + del model, trainer, train_samples, eval_samples + gc.collect() + + return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py new file mode 100644 index 0000000000..ebb11f29d1 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -0,0 +1,415 @@ +import argparse +import os +import sys +import tempfile + +import gradio as gr +import librosa.display +import numpy as np + +import os +import torch +import torchaudio +import traceback +from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list +from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + + +def clear_gpu_cache(): + # clear the GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +XTTS_MODEL = None +def load_model(xtts_checkpoint, xtts_config, xtts_vocab): + global XTTS_MODEL + clear_gpu_cache() + if not xtts_checkpoint or not xtts_config or not xtts_vocab: + return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!" + config = XttsConfig() + config.load_json(xtts_config) + XTTS_MODEL = Xtts.init_from_config(config) + print("Loading XTTS model! ") + XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) + if torch.cuda.is_available(): + XTTS_MODEL.cuda() + + print("Model Loaded!") + return "Model Loaded!" + +def run_tts(lang, tts_text, speaker_audio_file): + if XTTS_MODEL is None or not speaker_audio_file: + return "You need to run the previous step to load the model !!", None, None + + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + out = XTTS_MODEL.inference( + text=tts_text, + language=lang, + gpt_cond_latent=gpt_cond_latent, + speaker_embedding=speaker_embedding, + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + length_penalty=XTTS_MODEL.config.length_penalty, + repetition_penalty=XTTS_MODEL.config.repetition_penalty, + top_k=XTTS_MODEL.config.top_k, + top_p=XTTS_MODEL.config.top_p, + ) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + out["wav"] = torch.tensor(out["wav"]).unsqueeze(0) + out_path = fp.name + torchaudio.save(out_path, out["wav"], 24000) + + return "Speech generated !", out_path, speaker_audio_file + + + + +# define a logger to redirect +class Logger: + def __init__(self, filename="log.out"): + self.log_file = filename + self.terminal = sys.stdout + self.log = open(self.log_file, "w") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + self.terminal.flush() + self.log.flush() + + def isatty(self): + return False + +# redirect stdout and stderr to a file +sys.stdout = Logger() +sys.stderr = sys.stdout + + +# logging.basicConfig(stream=sys.stdout, level=logging.INFO) +import logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout) + ] +) + +def read_logs(): + sys.stdout.flush() + with open(sys.stdout.log_file, "r") as f: + return f.read() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""XTTS fine-tuning demo\n\n""" + """ + Example runs: + python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--port", + type=int, + help="Port to run the gradio demo. Default: 5003", + default=5003, + ) + parser.add_argument( + "--out_path", + type=str, + help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/", + default="/tmp/xtts_ft/", + ) + + parser.add_argument( + "--num_epochs", + type=int, + help="Number of epochs to train. Default: 10", + default=10, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size. Default: 4", + default=4, + ) + parser.add_argument( + "--grad_acumm", + type=int, + help="Grad accumulation steps. Default: 1", + default=1, + ) + parser.add_argument( + "--max_audio_length", + type=int, + help="Max permitted audio size in seconds. Default: 11", + default=11, + ) + + args = parser.parse_args() + + with gr.Blocks() as demo: + with gr.Tab("1 - Data processing"): + out_path = gr.Textbox( + label="Output path (where data and checkpoints will be saved):", + value=args.out_path, + ) + # upload_file = gr.Audio( + # sources="upload", + # label="Select here the audio files that you want to use for XTTS trainining !", + # type="filepath", + # ) + upload_file = gr.File( + file_count="multiple", + label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)", + ) + lang = gr.Dropdown( + label="Dataset Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja" + ], + ) + progress_data = gr.Label( + label="Progress:" + ) + logs = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs, every=1) + + prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") + + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): + clear_gpu_cache() + out_path = os.path.join(out_path, "dataset") + os.makedirs(out_path, exist_ok=True) + if audio_path is None: + return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" + else: + try: + train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + except: + traceback.print_exc() + error = traceback.format_exc() + return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", "" + + clear_gpu_cache() + + # if audio total len is less than 2 minutes raise an error + if audio_total_size < 120: + message = "The sum of the duration of the audios that you provided should be at least 2 minutes!" + print(message) + return message, "", "" + + print("Dataset Processed!") + return "Dataset Processed!", train_meta, eval_meta + + with gr.Tab("2 - Fine-tuning XTTS Encoder"): + train_csv = gr.Textbox( + label="Train CSV:", + ) + eval_csv = gr.Textbox( + label="Eval CSV:", + ) + num_epochs = gr.Slider( + label="Number of epochs:", + minimum=1, + maximum=100, + step=1, + value=args.num_epochs, + ) + batch_size = gr.Slider( + label="Batch size:", + minimum=2, + maximum=512, + step=1, + value=args.batch_size, + ) + grad_acumm = gr.Slider( + label="Grad accumulation steps:", + minimum=2, + maximum=128, + step=1, + value=args.grad_acumm, + ) + max_audio_length = gr.Slider( + label="Max permitted audio size in seconds:", + minimum=2, + maximum=20, + step=1, + value=args.max_audio_length, + ) + progress_train = gr.Label( + label="Progress:" + ) + logs_tts_train = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs_tts_train, every=1) + train_btn = gr.Button(value="Step 2 - Run the training") + + def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length): + clear_gpu_cache() + if not train_csv or not eval_csv: + return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + try: + # convert seconds to waveform frames + max_audio_length = int(max_audio_length * 22050) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length) + except: + traceback.print_exc() + error = traceback.format_exc() + return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "" + + # copy original files to avoid parameters changes issues + os.system(f"cp {config_path} {exp_path}") + os.system(f"cp {vocab_file} {exp_path}") + + ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth") + print("Model training done!") + clear_gpu_cache() + return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav + + with gr.Tab("3 - Inference"): + with gr.Row(): + with gr.Column() as col1: + xtts_checkpoint = gr.Textbox( + label="XTTS checkpoint path:", + value="", + ) + xtts_config = gr.Textbox( + label="XTTS config path:", + value="", + ) + + xtts_vocab = gr.Textbox( + label="XTTS vocab path:", + value="", + ) + progress_load = gr.Label( + label="Progress:" + ) + load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") + + with gr.Column() as col2: + speaker_reference_audio = gr.Textbox( + label="Speaker reference audio:", + value="", + ) + tts_language = gr.Dropdown( + label="Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja", + ] + ) + tts_text = gr.Textbox( + label="Input Text.", + value="This model sounds really good and above all, it's reasonably fast.", + ) + tts_btn = gr.Button(value="Step 4 - Inference") + + with gr.Column() as col3: + progress_gen = gr.Label( + label="Progress:" + ) + tts_output_audio = gr.Audio(label="Generated Audio.") + reference_audio = gr.Audio(label="Reference audio used.") + + prompt_compute_btn.click( + fn=preprocess_dataset, + inputs=[ + upload_file, + lang, + out_path, + ], + outputs=[ + progress_data, + train_csv, + eval_csv, + ], + ) + + + train_btn.click( + fn=train_model, + inputs=[ + lang, + train_csv, + eval_csv, + num_epochs, + batch_size, + grad_acumm, + out_path, + max_audio_length, + ], + outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], + ) + + load_btn.click( + fn=load_model, + inputs=[ + xtts_checkpoint, + xtts_config, + xtts_vocab + ], + outputs=[progress_load], + ) + + tts_btn.click( + fn=run_tts, + inputs=[ + tts_language, + tts_text, + speaker_reference_audio, + ], + outputs=[progress_gen, tts_output_audio, reference_audio], + ) + + demo.launch( + share=True, + debug=False, + server_port=args.port, + server_name="0.0.0.0" + ) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 61222dac8a..6276f60af6 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -225,11 +225,11 @@ def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 + test_audios = {} if self.config.test_sentences: # init gpt for inference mode self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) self.xtts.gpt.eval() - test_audios = {} print(" | > Synthesizing test sentences.") for idx, s_info in enumerate(self.config.test_sentences): wav = self.xtts.synthesize( diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py index ea6d98d3da..e59ccb6630 100644 --- a/TTS/tts/layers/xtts/zh_num2words.py +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -65,7 +65,7 @@ CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP PUNCS = CN_PUNCS + string.punctuation -PUNCS_TRANSFORM = str.maketrans(PUNCS, " " * len(PUNCS), "") # replace puncs with space +PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma # https://zh.wikipedia.org/wiki/全行和半行 diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 208ec4d561..6b8cc59101 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -272,6 +272,11 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = style_embs = [] for i in range(0, audio.shape[1], 22050 * chunk_length): audio_chunk = audio[:, i : i + 22050 * chunk_length] + + # if the chunk is too short ignore it + if audio_chunk.size(-1) < 22050 * 0.33: + continue + mel_chunk = wav_to_mel_cloning( audio_chunk, mel_norms=self.mel_stats.cpu(), diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 3952504d0b..463b840242 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -332,9 +332,9 @@ def _set_model_item(self, model_name): def ask_tos(model_full_path): """Ask the user to agree to the terms of service""" tos_path = os.path.join(model_full_path, "tos_agreed.txt") - print(" > You must agree to the terms of service to use this model.") - print(" | > Please see the terms of service at https://coqui.ai/cpml.txt") - print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]') + print(" > You must confirm the following:") + print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"') + print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]') answer = input(" | | > ") if answer.lower() == "y": with open(tos_path, "w", encoding="utf-8") as f: diff --git a/docs/source/configuration.md b/docs/source/configuration.md index cde7e073e9..ada61e16db 100644 --- a/docs/source/configuration.md +++ b/docs/source/configuration.md @@ -56,4 +56,4 @@ ModelConfig() In the example above, ```ModelConfig()``` is the final configuration that the model receives and it has all the fields necessary for the model. -We host pre-defined model configurations under ```TTS//configs/```.Although we recommend a unified config class, you can decompose it as you like as for your custom models as long as all the fields for the trainer, model, and inference APIs are provided. \ No newline at end of file +We host pre-defined model configurations under ```TTS//configs/```. Although we recommend a unified config class, you can decompose it as you like as for your custom models as long as all the fields for the trainer, model, and inference APIs are provided. diff --git a/docs/source/finetuning.md b/docs/source/finetuning.md index c236260d0c..069f565137 100644 --- a/docs/source/finetuning.md +++ b/docs/source/finetuning.md @@ -21,7 +21,7 @@ them and fine-tune it for your own dataset. This will help you in two main ways: Fine-tuning comes to the rescue in this case. You can take one of our pre-trained models and fine-tune it on your own speech dataset and achieve reasonable results with only a couple of hours of data. - However, note that, fine-tuning does not ensure great results. The model performance is still depends on the + However, note that, fine-tuning does not ensure great results. The model performance still depends on the {ref}`dataset quality ` and the hyper-parameters you choose for fine-tuning. Therefore, it still takes a bit of tinkering. @@ -41,7 +41,7 @@ them and fine-tune it for your own dataset. This will help you in two main ways: tts --list_models ``` - The command above lists the the models in a naming format as ```///```. + The command above lists the models in a naming format as ```///```. Or you can manually check the `.model.json` file in the project directory. diff --git a/docs/source/formatting_your_dataset.md b/docs/source/formatting_your_dataset.md index 796c7b6d06..23c497d0bf 100644 --- a/docs/source/formatting_your_dataset.md +++ b/docs/source/formatting_your_dataset.md @@ -7,7 +7,7 @@ If you have a single audio file and you need to split it into clips, there are d It is also important to use a lossless audio file format to prevent compression artifacts. We recommend using `wav` file format. -Let's assume you created the audio clips and their transcription. You can collect all your clips under a folder. Let's call this folder `wavs`. +Let's assume you created the audio clips and their transcription. You can collect all your clips in a folder. Let's call this folder `wavs`. ``` /wavs @@ -17,7 +17,7 @@ Let's assume you created the audio clips and their transcription. You can collec ... ``` -You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each column must be delimitered by a special character separating the audio file name, the transcription and the normalized transcription. And make sure that the delimiter is not used in the transcription text. +You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each column must be delimited by a special character separating the audio file name, the transcription and the normalized transcription. And make sure that the delimiter is not used in the transcription text. We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc. @@ -55,7 +55,7 @@ For more info about dataset qualities and properties check our [post](https://gi After you collect and format your dataset, you need to check two things. Whether you need a `formatter` and a `text_cleaner`. The `formatter` loads the text file (created above) as a list and the `text_cleaner` performs a sequence of text normalization operations that converts the raw text into the spoken representation (e.g. converting numbers to text, acronyms, and symbols to the spoken format). -If you use a different dataset format then the LJSpeech or the other public datasets that 🐸TTS supports, then you need to write your own `formatter`. +If you use a different dataset format than the LJSpeech or the other public datasets that 🐸TTS supports, then you need to write your own `formatter`. If your dataset is in a new language or it needs special normalization steps, then you need a new `text_cleaner`. diff --git a/docs/source/implementing_a_new_language_frontend.md b/docs/source/implementing_a_new_language_frontend.md index f4f6a04a5f..2041352d64 100644 --- a/docs/source/implementing_a_new_language_frontend.md +++ b/docs/source/implementing_a_new_language_frontend.md @@ -2,11 +2,11 @@ - Language frontends are located under `TTS.tts.utils.text` - Each special language has a separate folder. -- Each folder containst all the utilities for processing the text input. +- Each folder contains all the utilities for processing the text input. - `TTS.tts.utils.text.phonemizers` contains the main phonemizer for a language. This is the class that uses the utilities from the previous step and used to convert the text to phonemes or graphemes for the model. - After you implement your phonemizer, you need to add it to the `TTS/tts/utils/text/phonemizers/__init__.py` to be able to map the language code in the model config - `config.phoneme_language` - to the phonemizer class and initiate the phonemizer automatically. - You should also add tests to `tests/text_tests` if you want to make a PR. -We suggest you to check the available implementations as reference. Good luck! \ No newline at end of file +We suggest you to check the available implementations as reference. Good luck! diff --git a/docs/source/implementing_a_new_model.md b/docs/source/implementing_a_new_model.md index e2a0437e9a..1bf7a8822e 100644 --- a/docs/source/implementing_a_new_model.md +++ b/docs/source/implementing_a_new_model.md @@ -145,7 +145,7 @@ class MyModel(BaseTTS): Args: ap (AudioProcessor): audio processor used at training. batch (Dict): Model inputs used at the previous training step. - outputs (Dict): Model outputs generated at the previoud training step. + outputs (Dict): Model outputs generated at the previous training step. Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. @@ -183,7 +183,7 @@ class MyModel(BaseTTS): ... def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: - """Setup an return optimizer or optimizers.""" + """Setup a return optimizer or optimizers.""" pass def get_lr(self) -> Union[float, List[float]]: diff --git a/docs/source/marytts.md b/docs/source/marytts.md index 81d547107d..9091ca330f 100644 --- a/docs/source/marytts.md +++ b/docs/source/marytts.md @@ -2,13 +2,13 @@ ## What is Mary-TTS? -[Mary (Modular Architecture for Research in sYynthesis) Text-to-Speech](http://mary.dfki.de/) is an open-source (GNU LGPL license), multilingual Text-to-Speech Synthesis platform written in Java. It was originally developed as a collaborative project of [DFKI’s](http://www.dfki.de/web) Language Technology Lab and the [Institute of Phonetics](http://www.coli.uni-saarland.de/groups/WB/Phonetics/) at Saarland University, Germany. It is now maintained by the Multimodal Speech Processing Group in the [Cluster of Excellence MMCI](https://www.mmci.uni-saarland.de/) and DFKI. +[Mary (Modular Architecture for Research in sYnthesis) Text-to-Speech](http://mary.dfki.de/) is an open-source (GNU LGPL license), multilingual Text-to-Speech Synthesis platform written in Java. It was originally developed as a collaborative project of [DFKI’s](http://www.dfki.de/web) Language Technology Lab and the [Institute of Phonetics](http://www.coli.uni-saarland.de/groups/WB/Phonetics/) at Saarland University, Germany. It is now maintained by the Multimodal Speech Processing Group in the [Cluster of Excellence MMCI](https://www.mmci.uni-saarland.de/) and DFKI. MaryTTS has been around for a very! long time. Version 3.0 even dates back to 2006, long before Deep Learning was a broadly known term and the last official release was version 5.2 in 2016. You can check out this OpenVoice-Tech page to learn more: https://openvoice-tech.net/index.php/MaryTTS ## Why Mary-TTS compatibility is relevant -Due to it's open-source nature, relatively high quality voices and fast synthetization speed Mary-TTS was a popular choice in the past and many tools implemented API support over the years like screen-readers (NVDA + SpeechHub), smart-home HUBs (openHAB, Home Assistant) or voice assistants (Rhasspy, Mycroft, SEPIA). A compatibility layer for Coqui-TTS will ensure that these tools can use Coqui as a drop-in replacement and get even better voices right away. +Due to its open-source nature, relatively high quality voices and fast synthetization speed Mary-TTS was a popular choice in the past and many tools implemented API support over the years like screen-readers (NVDA + SpeechHub), smart-home HUBs (openHAB, Home Assistant) or voice assistants (Rhasspy, Mycroft, SEPIA). A compatibility layer for Coqui-TTS will ensure that these tools can use Coqui as a drop-in replacement and get even better voices right away. ## API and code examples @@ -40,4 +40,4 @@ You can enter the same URLs in your browser and check-out the results there as w ### How it works and limitations A classic Mary-TTS server would usually show all installed locales and voices via the corresponding endpoints and accept the parameters `LOCALE` and `VOICE` for processing. For Coqui-TTS we usually start the server with one specific locale and model and thus cannot return all available options. Instead we return the active locale and use the model name as "voice". Since we only have one active model and always want to return a WAV-file, we currently ignore all other processing parameters except `INPUT_TEXT`. Since the gender is not defined for models in Coqui-TTS we always return `u` (undefined). -We think that this is an acceptable compromise, since users are often only interested in one specific voice anyways, but the API might get extended in the future to support multiple languages and voices at the same time. \ No newline at end of file +We think that this is an acceptable compromise, since users are often only interested in one specific voice anyways, but the API might get extended in the future to support multiple languages and voices at the same time. diff --git a/docs/source/models/tortoise.md b/docs/source/models/tortoise.md index 2df6da7649..1a8e9ca8e9 100644 --- a/docs/source/models/tortoise.md +++ b/docs/source/models/tortoise.md @@ -1,6 +1,6 @@ # 🐢 Tortoise Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on an GPT like autogressive acoustic model that converts input -text to discritized acouistic tokens, a diffusion model that converts these tokens to melspeectrogram frames and a Univnet vocoder to convert the spectrograms to +text to discritized acoustic tokens, a diffusion model that converts these tokens to melspectrogram frames and a Univnet vocoder to convert the spectrograms to the final audio signal. The important downside is that Tortoise is very slow compared to the parallel TTS models like VITS. Big thanks to 👑[@manmay-nakhashi](https://github.com/manmay-nakhashi) who helped us implement Tortoise in 🐸TTS. diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index 7e461a49ff..acb73114b3 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -81,42 +81,6 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t language="en") ``` -##### Streaming inference - -XTTS supports streaming inference. This is useful for real-time applications. - -```python -import os -import time -import torch -import torchaudio - -print("Loading model...") -tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) -model = tts.synthesizer.tts_model - -print("Computing speaker latents...") -gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"]) - -print("Inference...") -t0 = time.time() -stream_generator = model.inference_stream( - "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", - "en", - gpt_cond_latent, - speaker_embedding -) - -wav_chuncks = [] -for i, chunk in enumerate(stream_generator): - if i == 0: - print(f"Time to first chunck: {time.time() - t0}") - print(f"Received chunk {i} of audio length {chunk.shape[-1]}") - wav_chuncks.append(chunk) -wav = torch.cat(wav_chuncks, dim=0) -torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000) -``` - #### 🐸TTS Command line ##### Single reference @@ -150,14 +114,32 @@ or for all wav files in a directory you can use: To use the model API, you need to download the model files and pass config and model file paths manually. -##### Calling manually +#### Manual Inference -If you want to be able to run with `use_deepspeed=True` and **enjoy the speedup**, you need to install deepspeed first. +If you want to be able to `load_checkpoint` with `use_deepspeed=True` and **enjoy the speedup**, you need to install deepspeed first. ```console pip install deepspeed==0.10.3 ``` +##### inference parameters + +- `text`: The text to be synthesized. +- `language`: The language of the text to be synthesized. +- `gpt_cond_latent`: The latent vector you get with get_conditioning_latents. (You can cache for faster inference with same speaker) +- `speaker_embedding`: The speaker embedding you get with get_conditioning_latents. (You can cache for faster inference with same speaker) +- `temperature`: The softmax temperature of the autoregressive model. Defaults to 0.65. +- `length_penalty`: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. Defaults to 1.0. +- `repetition_penalty`: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0. +- `top_k`: Lower values mean the decoder produces more "likely" (aka boring) outputs. Defaults to 50. +- `top_p`: Lower values mean the decoder produces more "likely" (aka boring) outputs. Defaults to 0.8. +- `speed`: The speed rate of the generated audio. Defaults to 1.0. (can produce artifacts if far from 1.0) +- `enable_text_splitting`: Whether to split the text into sentences and generate audio for each sentence. It allows you to have infinite input length but might loose important context between sentences. Defaults to True. + + +##### Inference + + ```python import os import torch @@ -233,6 +215,50 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000) ### Training +#### Easy training +To make `XTTS_v2` GPT encoder training easier for beginner users we did a gradio demo that implements the whole fine-tuning pipeline. The gradio demo enables the user to easily do the following steps: + +- Preprocessing of the uploaded audio or audio files in 🐸 TTS coqui formatter +- Train the XTTS GPT encoder with the processed data +- Inference support using the fine-tuned model + +The user can run this gradio demo locally or remotely using a Colab Notebook. + +##### Run demo on Colab +To make the `XTTS_v2` fine-tuning more accessible for users that do not have good GPUs available we did a Google Colab Notebook. + +The Colab Notebook is available [here](https://colab.research.google.com/drive/1GiI4_X724M8q2W-zZ-jXo7cWTV7RfaH-?usp=sharing). + +To learn how to use this Colab Notebook please check the [XTTS fine-tuning video](). + +If you are not able to acess the video you need to follow the steps: + +1. Open the Colab notebook and start the demo by runining the first two cells (ignore pip install errors in the first one). +2. Click on the link "Running on public URL:" on the second cell output. +3. On the first Tab (1 - Data processing) you need to select the audio file or files, wait for upload, and then click on the button "Step 1 - Create dataset" and then wait until the dataset processing is done. +4. Soon as the dataset processing is done you need to go to the second Tab (2 - Fine-tuning XTTS Encoder) and press the button "Step 2 - Run the training" and then wait until the training is finished. Note that it can take up to 40 minutes. +5. Soon the training is done you can go to the third Tab (3 - Inference) and then click on the button "Step 3 - Load Fine-tuned XTTS model" and wait until the fine-tuned model is loaded. Then you can do the inference on the model by clicking on the button "Step 4 - Inference". + + +##### Run demo locally + +To run the demo locally you need to do the following steps: +1. Install 🐸 TTS following the instructions available [here](https://tts.readthedocs.io/en/dev/installation.html#installation). +2. Install the Gradio demo requirements with the command `python3 -m pip install -r TTS/demos/xtts_ft_demo/requirements.txt` +3. Run the Gradio demo using the command `python3 TTS/demos/xtts_ft_demo/xtts_demo.py` +4. Follow the steps presented in the [tutorial video](https://www.youtube.com/watch?v=8tpDiiouGxc&feature=youtu.be) to be able to fine-tune and test the fine-tuned model. + + +If you are not able to access the video, here is what you need to do: + +1. On the first Tab (1 - Data processing) select the audio file or files, wait for upload +2. Click on the button "Step 1 - Create dataset" and then wait until the dataset processing is done. +3. Go to the second Tab (2 - Fine-tuning XTTS Encoder) and press the button "Step 2 - Run the training" and then wait until the training is finished. it will take some time. +4. Go to the third Tab (3 - Inference) and then click on the button "Step 3 - Load Fine-tuned XTTS model" and wait until the fine-tuned model is loaded. +5. Now you can run inference with the model by clicking on the button "Step 4 - Inference". + +#### Advanced training + A recipe for `XTTS_v2` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it. @@ -280,6 +306,7 @@ torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000) ``` + ## References and Acknowledgements - VallE: https://arxiv.org/abs/2301.02111 - Tortoise Repo: https://github.com/neonbjb/tortoise-tts