diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index e0ffbd71ed..d0392fc943 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -358,7 +358,10 @@ def tts( if speaker_wav is not None and self.tts_model.speaker_manager is not None: speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + vocoder_device = "cpu" use_gl = self.vocoder_model is None + if not use_gl: + vocoder_device = next(self.vocoder_model.parameters()).device if not reference_wav: # not voice conversion for sen in sens: @@ -390,7 +393,6 @@ def tts( mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T - device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch @@ -405,7 +407,7 @@ def tts( vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable # run vocoder model # [1, T, C] - waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) if self.use_cuda and not use_gl: waveform = waveform.cpu() if not use_gl: @@ -455,7 +457,6 @@ def tts( mel_postnet_spec = outputs[0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T - device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch @@ -470,9 +471,8 @@ def tts( vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable # run vocoder model # [1, T, C] - waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) - if self.use_cuda: - waveform = waveform.cpu() + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy() wavs = waveform.squeeze()