Skip to content

Commit

Permalink
feature: use device for vocoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake Tae authored and Jake Tae committed Aug 10, 2023
1 parent b64a259 commit 3272b0b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,10 @@ def tts(
if speaker_wav is not None and self.tts_model.speaker_manager is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

vocoder_device = "cpu"
use_gl = self.vocoder_model is None
if not use_gl:
vocoder_device = next(self.vocoder_model.parameters()).device

if not reference_wav: # not voice conversion
for sen in sens:
Expand Down Expand Up @@ -390,7 +393,6 @@ def tts(
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu"
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
Expand All @@ -405,7 +407,7 @@ def tts(
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
if self.use_cuda and not use_gl:
waveform = waveform.cpu()
if not use_gl:
Expand Down Expand Up @@ -455,7 +457,6 @@ def tts(
mel_postnet_spec = outputs[0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu"
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
Expand All @@ -470,9 +471,8 @@ def tts(
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
if self.use_cuda:
waveform = waveform.cpu()
waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
waveform = waveform.cpu()
if not use_gl:
waveform = waveform.numpy()
wavs = waveform.squeeze()
Expand Down

0 comments on commit 3272b0b

Please sign in to comment.