Skip to content

Commit

Permalink
fix: retain backwards compatibility in functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake Tae authored and Jake Tae committed Aug 10, 2023
1 parent 1550316 commit 3fb5f47
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions TTS/tts/utils/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from torch import nn


def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
    """Convert a numpy array into a torch tensor on the requested device.

    Args:
        np_array: Source numpy array, or ``None``.
        dtype: Target ``torch`` dtype for the resulting tensor.
        cuda: Legacy flag kept for backwards compatibility; when ``True``
            it overrides ``device`` with ``"cuda"``.
        device: Target device string (ignored if ``cuda`` is ``True``).

    Returns:
        A ``torch.Tensor`` on the selected device, or ``None`` when the
        input array is ``None``.
    """
    # Legacy `cuda` flag wins over the newer `device` argument.
    if cuda:
        device = "cuda"
    if np_array is None:
        return None
    return torch.as_tensor(np_array, dtype=dtype, device=device)


def compute_style_mel(style_wav, ap, device="cpu"):
def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
if cuda:
device = "cuda"
style_mel = torch.FloatTensor(
ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device,
).unsqueeze(0)
Expand Down Expand Up @@ -71,14 +75,18 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav


def id_to_torch(aux_id, cuda=False, device="cpu"):
    """Convert an auxiliary id (e.g. speaker id) to a torch tensor.

    Args:
        aux_id: Integer-like id (or ``None``) to wrap in a tensor.
        cuda: Legacy flag kept for backwards compatibility; when ``True``
            it overrides ``device`` with ``"cuda"``.
        device: Target device string (ignored if ``cuda`` is ``True``).

    Returns:
        A tensor on the selected device, or ``None`` when ``aux_id`` is
        ``None``.
    """
    # Legacy `cuda` flag wins over the newer `device` argument.
    if cuda:
        device = "cuda"
    if aux_id is None:
        return None
    return torch.from_numpy(np.asarray(aux_id)).to(device)


def embedding_to_torch(d_vector, device="cpu"):
def embedding_to_torch(d_vector, cuda=False, device="cpu"):
if cuda:
device = "cuda"
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
Expand Down

0 comments on commit 3fb5f47

Please sign in to comment.