From 155031654947a914b0b5c6201af6abe52e6e4660 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 10 Aug 2023 14:31:30 -0400 Subject: [PATCH] feature: support `.to(device)` in tts and synthesizer --- TTS/api.py | 8 ++++++- TTS/tts/utils/synthesis.py | 45 ++++++++++++++++++-------------------- TTS/utils/synthesizer.py | 4 +++- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 861c541826..578b038f1f 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -3,12 +3,14 @@ import os import tempfile import urllib.request +import warnings from pathlib import Path from typing import Tuple, Union import numpy as np import requests from scipy.io import wavfile +from torch import nn from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager @@ -229,7 +231,7 @@ def tts_to_file( return file_path -class TTS: +class TTS(nn.Module): """TODO: Add voice conversion and Capacitron support.""" def __init__( @@ -277,6 +279,7 @@ def __init__( progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ + super().__init__() self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) self.synthesizer = None @@ -284,6 +287,9 @@ def __init__( self.csapi = None self.model_name = None + if gpu: + warnings.warn("`gpu` will be deprecated. 
Please use `tts.to(device)` instead.") + if model_name is not None: if "tts_models" in model_name or "coqui_studio" in model_name: self.load_tts_model_by_name(model_name, gpu) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 039816db1f..f4133ffc34 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -5,19 +5,17 @@ from torch import nn -def numpy_to_torch(np_array, dtype, cuda=False): +def numpy_to_torch(np_array, dtype, device="cpu"): if np_array is None: return None - tensor = torch.as_tensor(np_array, dtype=dtype) - if cuda: - return tensor.cuda() + tensor = torch.as_tensor(np_array, dtype=dtype, device=device) return tensor -def compute_style_mel(style_wav, ap, cuda=False): - style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) - if cuda: - return style_mel.cuda() +def compute_style_mel(style_wav, ap, device="cpu"): + style_mel = torch.as_tensor( + ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), dtype=torch.float, device=device, + ).unsqueeze(0) return style_mel @@ -73,22 +71,18 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def id_to_torch(aux_id, cuda=False): +def id_to_torch(aux_id, device="cpu"): if aux_id is not None: aux_id = np.asarray(aux_id) - aux_id = torch.from_numpy(aux_id) - if cuda: - return aux_id.cuda() + aux_id = torch.from_numpy(aux_id).to(device) return aux_id -def embedding_to_torch(d_vector, cuda=False): +def embedding_to_torch(d_vector, device="cpu"): if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) - d_vector = d_vector.squeeze().unsqueeze(0) - if cuda: - return d_vector.cuda() + d_vector = d_vector.squeeze().unsqueeze(0).to(device) return d_vector @@ -162,6 +156,9 @@ def synthesis( language_id (int): Language ID passed to the language embedding layer in multi-langual model. Defaults to None. 
""" + # device + device = next(model.parameters()).device + # GST or Capacitron processing # TODO: need to handle the case of setting both gst and capacitron to true somewhere style_mel = None @@ -169,10 +166,10 @@ def synthesis( if isinstance(style_wav, dict): style_mel = style_wav else: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, device=device) if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, device=device) style_mel = style_mel.transpose(1, 2) # [1, time, depth] language_name = None @@ -188,26 +185,26 @@ def synthesis( ) # pass tensors to backend if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + speaker_id = id_to_torch(speaker_id, device=device) if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + d_vector = embedding_to_torch(d_vector, device=device) if language_id is not None: - language_id = id_to_torch(language_id, cuda=use_cuda) + language_id = id_to_torch(language_id, device=device) if not isinstance(style_mel, dict): # GST or Capacitron style mel - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + style_mel = numpy_to_torch(style_mel, torch.float, device=device) if style_text is not None: style_text = np.asarray( model.tokenizer.text_to_ids(style_text, language=language_id), dtype=np.int32, ) - style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda) + style_text = numpy_to_torch(style_text, torch.long, device=device) style_text = style_text.unsqueeze(0) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) text_inputs = text_inputs.unsqueeze(0) # synthesize voice outputs = run_model_torch( diff --git a/TTS/utils/synthesizer.py 
b/TTS/utils/synthesizer.py index bc0e231df0..e0ffbd71ed 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,6 +5,7 @@ import numpy as np import pysbd import torch +from torch import nn from TTS.config import load_config from TTS.tts.configs.vits_config import VitsConfig @@ -21,7 +22,7 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input -class Synthesizer(object): +class Synthesizer(nn.Module): def __init__( self, tts_checkpoint: str = "", @@ -60,6 +61,7 @@ def __init__( vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, use_cuda (bool, optional): enable/disable cuda. Defaults to False. """ + super().__init__() self.tts_checkpoint = tts_checkpoint self.tts_config_path = tts_config_path self.tts_speakers_file = tts_speakers_file