Skip to content

Commit

Permalink
feature: support .to(device) in tts and synthesizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake Tae authored and Jake Tae committed Aug 10, 2023
1 parent c87377b commit 1550316
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 26 deletions.
8 changes: 7 additions & 1 deletion TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import os
import tempfile
import urllib.request
import warnings
from pathlib import Path
from typing import Tuple, Union

import numpy as np
import requests
from scipy.io import wavfile
from torch import nn

from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
Expand Down Expand Up @@ -229,7 +231,7 @@ def tts_to_file(
return file_path


class TTS:
class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""

def __init__(
Expand Down Expand Up @@ -277,13 +279,17 @@ def __init__(
progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
super().__init__()
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

self.synthesizer = None
self.voice_converter = None
self.csapi = None
self.model_name = None

if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")

if model_name is not None:
if "tts_models" in model_name or "coqui_studio" in model_name:
self.load_tts_model_by_name(model_name, gpu)
Expand Down
45 changes: 21 additions & 24 deletions TTS/tts/utils/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@
from torch import nn


def numpy_to_torch(np_array, dtype, cuda=False):
def numpy_to_torch(np_array, dtype, device="cpu"):
if np_array is None:
return None
tensor = torch.as_tensor(np_array, dtype=dtype)
if cuda:
return tensor.cuda()
tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
return tensor


def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
if cuda:
return style_mel.cuda()
def compute_style_mel(style_wav, ap, device="cpu"):
style_mel = torch.FloatTensor(
ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device,
).unsqueeze(0)
return style_mel


Expand Down Expand Up @@ -73,22 +71,18 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav


def id_to_torch(aux_id, cuda=False):
def id_to_torch(aux_id, device="cpu"):
if aux_id is not None:
aux_id = np.asarray(aux_id)
aux_id = torch.from_numpy(aux_id)
if cuda:
return aux_id.cuda()
aux_id = torch.from_numpy(aux_id).to(device)
return aux_id


def embedding_to_torch(d_vector, cuda=False):
def embedding_to_torch(d_vector, device="cpu"):
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
d_vector = d_vector.squeeze().unsqueeze(0)
if cuda:
return d_vector.cuda()
d_vector = d_vector.squeeze().unsqueeze(0).to(device)
return d_vector


Expand Down Expand Up @@ -162,17 +156,20 @@ def synthesis(
language_id (int):
Language ID passed to the language embedding layer in a multi-lingual model. Defaults to None.
"""
# device
device = next(model.parameters()).device

# GST or Capacitron processing
# TODO: need to handle the case of setting both gst and capacitron to true somewhere
style_mel = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = compute_style_mel(style_wav, model.ap, device=device)

if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = compute_style_mel(style_wav, model.ap, device=device)
style_mel = style_mel.transpose(1, 2) # [1, time, depth]

language_name = None
Expand All @@ -188,26 +185,26 @@ def synthesis(
)
# pass tensors to backend
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
speaker_id = id_to_torch(speaker_id, device=device)

if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
d_vector = embedding_to_torch(d_vector, device=device)

if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
language_id = id_to_torch(language_id, device=device)

if not isinstance(style_mel, dict):
# GST or Capacitron style mel
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
style_mel = numpy_to_torch(style_mel, torch.float, device=device)
if style_text is not None:
style_text = np.asarray(
model.tokenizer.text_to_ids(style_text, language=language_id),
dtype=np.int32,
)
style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
style_text = numpy_to_torch(style_text, torch.long, device=device)
style_text = style_text.unsqueeze(0)

text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(
Expand Down
4 changes: 3 additions & 1 deletion TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pysbd
import torch
from torch import nn

from TTS.config import load_config
from TTS.tts.configs.vits_config import VitsConfig
Expand All @@ -21,7 +22,7 @@
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input


class Synthesizer(object):
class Synthesizer(nn.Module):
def __init__(
self,
tts_checkpoint: str = "",
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
"""
super().__init__()
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
Expand Down

0 comments on commit 1550316

Please sign in to comment.