From 3fb5f4709c9c5cf9ecf5983c532909024cd81260 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 10 Aug 2023 14:59:52 -0400 Subject: [PATCH] fix: retain backwards compatability in functions --- TTS/tts/utils/synthesis.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index f4133ffc34..d87d5abda6 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -5,14 +5,18 @@ from torch import nn -def numpy_to_torch(np_array, dtype, device="cpu"): +def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"): + if cuda: + device = "cuda" if np_array is None: return None tensor = torch.as_tensor(np_array, dtype=dtype, device=device) return tensor -def compute_style_mel(style_wav, ap, device="cpu"): +def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): + if cuda: + device = "cuda" style_mel = torch.FloatTensor( ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device, ).unsqueeze(0) @@ -71,14 +75,18 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def id_to_torch(aux_id, device="cpu"): +def id_to_torch(aux_id, cuda=False, device="cpu"): + if cuda: + device = "cuda" if aux_id is not None: aux_id = np.asarray(aux_id) aux_id = torch.from_numpy(aux_id).to(device) return aux_id -def embedding_to_torch(d_vector, device="cpu"): +def embedding_to_torch(d_vector, cuda=False, device="cpu"): + if cuda: + device = "cuda" if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)