diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 7585e4a488..00abd007c7 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -726,27 +726,41 @@ def init_for_audio_edit(self, dvae_checkpoint, mel_norm_file):
 
     @torch.inference_mode()
     def audio_edit(
-            self,
-            text_left,
-            text_edit,
-            text_right,
-            audio_left_path,
-            audio_right_path,
-            language,
-            gpt_cond_latent,
-            speaker_embedding,
-            generate_times=5,
-            # GPT inference
-            temperature=0.75,
-            length_penalty=1.0,
-            repetition_penalty=10.0,
-            top_k=50,
-            top_p=0.85,
-            do_sample=True,
-            num_beams=1,
-            speed=1.0,
-            enable_text_splitting=False,
-            **hf_generate_kwargs,):
+        self,
+        text_left,
+        text_edit,
+        text_right,
+        audio_left_path,
+        audio_right_path,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        generate_times=5,
+        # GPT inference
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        num_beams=1,
+        speed=1.0,
+        enable_text_splitting=False,
+        **hf_generate_kwargs,):
+        """
+        Generate an audio clip for the given text_edit, spliced between audio_left and audio_right.
+
+        Args:
+            text_left: (str) Text of the left audio.
+            text_edit: (str) Text of the span you want to edit in.
+            text_right: (str) Text of the right audio.
+            audio_left_path: (str) Path of the left audio file.
+            audio_right_path: (str) Path of the right audio file.
+            generate_times: (int) Number of candidate generations; the best one is kept.
+
+        Returns:
+            Generated audio clip(s) as a torch tensor.
+        """
         load_sr: int = 22050
 
         audio_left = load_audio(audio_left_path, load_sr)
@@ -766,27 +780,47 @@ def audio_edit(
         wavs = []
 
         with torch.no_grad():
-            sent = text_left + text_edit
-            sent = sent.strip().lower()
-            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
-
-            sent_left = text_left.strip().lower()
-            text_left_tokens = torch.IntTensor(self.tokenizer.encode(sent_left, lang=language)).unsqueeze(0).to(self.device)
-
-            sent_right = text_edit + text_right
-            sent_right = sent_right.strip().lower()
-            text_right_tokens = torch.IntTensor(self.tokenizer.encode(sent_right, lang=language)).unsqueeze(0).to(self.device)
+            text_left = text_left.strip().lower()
+            text_right = text_right.strip().lower()
+            text_edit = text_edit.strip().lower()
+
+            sent_left_edit = text_left + text_edit
+            sent_right_edit = text_edit + text_right
+
+            text_left_edit_tokens = torch.IntTensor(
+                self.tokenizer.encode(sent_left_edit, lang=language)).unsqueeze(0).to(self.device)
+            text_right_edit_tokens = torch.IntTensor(
+                self.tokenizer.encode(sent_right_edit, lang=language)).unsqueeze(0).to(self.device)
+            text_left_tokens = torch.IntTensor(
+                self.tokenizer.encode(text_left, lang=language)).unsqueeze(0).to(self.device)
+            text_right_tokens = torch.IntTensor(
+                self.tokenizer.encode(text_right, lang=language)).unsqueeze(0).to(self.device)
 
             assert (
-                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+                text_left_edit_tokens.shape[-1] < self.args.gpt_max_text_tokens
             ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
 
-            gpt_inputs = self.gpt.compute_embeddings(gpt_cond_latent, text_tokens)
+            # compute the latents of the ground-truth right audio
+            expected_right_output_len = torch.tensor(
+                [codes_right.shape[-1] * self.gpt.code_stride_len], device=text_right_tokens.device
+            )
+            text_right_len = torch.tensor([text_right_tokens.shape[-1]], device=self.device)
+            right_latents = self.gpt(
+                text_right_tokens,
+                text_right_len,
+                codes_right,
+                expected_right_output_len,
+                cond_latents=gpt_cond_latent,
+                return_attentions=False,
+                return_latent=True,
+            )
+
+            gpt_inputs = self.gpt.compute_embeddings(gpt_cond_latent, text_left_edit_tokens)
 
             # add codes_left to inputs
             gpt_inputs = torch.cat([gpt_inputs, codes_left], dim=1)
 
             best_gpt_latents = None
-            best_logits = 0
+            best_loss = float("inf")
 
             for _ in range(generate_times):
                 gen = self.gpt.gpt_inference.generate(
                     gpt_inputs,
@@ -809,38 +843,34 @@ def audio_edit(
 
                 gen_len = gpt_codes.shape[-1]
 
-                gpt_right_codes = torch.cat([gpt_codes, codes_right], dim=1)
-                expected_output_len = torch.tensor(
-                    [gpt_right_codes.shape[-1] * self.gpt.code_stride_len], device=text_right_tokens.device
+                gpt_right_edit_codes = torch.cat([gpt_codes, codes_right], dim=1)
+                expected_right_edit_output_len = torch.tensor(
+                    [gpt_right_edit_codes.shape[-1] * self.gpt.code_stride_len], device=text_right_edit_tokens.device
                 )
-                text_len = torch.tensor([text_right_tokens.shape[-1]], device=self.device)
+                text_right_edit_len = torch.tensor([text_right_edit_tokens.shape[-1]], device=self.device)
 
-                gpt_latents = self.gpt(
-                    text_right_tokens,
-                    text_len,
-                    gpt_right_codes,
-                    expected_output_len,
+                right_edit_latents = self.gpt(
+                    text_right_edit_tokens,
+                    text_right_edit_len,
+                    gpt_right_edit_codes,
+                    expected_right_edit_output_len,
                     cond_latents=gpt_cond_latent,
                     return_attentions=False,
                     return_latent=True,
                 )
+                right_fake_latents = right_edit_latents[:, gen_len:]
 
-                right_logits = torch.gather(F.softmax(gpt_latents[:, gen_len:], dim=-1),
-                                            -1,
-                                            codes_right.unsqueeze(0).transpose(-1, -2))
-                sum_logits = torch.sum(right_logits, dim=-2)
-                sum_logits = float(sum_logits[0,0])
-
-                if sum_logits > best_logits:
-                    best_logits = sum_logits
-                    best_gpt_latents = gpt_latents.detach()
-
+                # keep the candidate whose right-context latents best match the real right audio
+                l1 = F.l1_loss(right_fake_latents, right_latents)
+                if l1 < best_loss:
+                    best_loss = l1
+                    best_gpt_latents = right_edit_latents[:, :gen_len].detach()
 
             expected_left_len = torch.tensor(
                 [codes_left.shape[-1] * self.gpt.code_stride_len], device=text_left_tokens.device
             )
             text_left_len = torch.tensor([text_left_tokens.shape[-1]], device=self.device)
-            gpt_left_latents = self.gpt(
+            left_latents = self.gpt(
                 text_left_tokens,
                 text_left_len,
                 codes_left,
@@ -850,15 +880,15 @@ def audio_edit(
                 return_attentions=False,
                 return_latent=True,
             )
 
-            # if length_scale != 1.0:
-            #     best_gpt_latents = F.interpolate(
-            #         best_gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-            #     ).transpose(1, 2)
+            if length_scale != 1.0:
+                best_gpt_latents = F.interpolate(
+                    best_gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                ).transpose(1, 2)
 
             # wavs.append(load_audio(audio_left_path, self.args.output_sample_rate).cpu().squeeze())
             # wavs.append(self.hifigan_decoder(best_gpt_latents, g=speaker_embedding).cpu().squeeze())
             # wavs.append(load_audio(audio_right_path, self.args.output_sample_rate).cpu().squeeze())
-            wavs.append(self.hifigan_decoder(torch.cat([gpt_left_latents, best_gpt_latents], dim=1),
-                                             g=speaker_embedding).cpu().squeeze())
+            wavs.append(self.hifigan_decoder(torch.cat([left_latents, best_gpt_latents, right_latents], dim=1),
+                                             g=speaker_embedding).cpu().squeeze())
         return {
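
For reference, a minimal usage sketch of the new API. The checkpoint paths, texts, and the `"wav"` key on the returned dict are assumptions (the key mirrors what `inference()` returns); `get_conditioning_latents`, `load_checkpoint`, and `init_from_config` are the standard XTTS calls, and `init_for_audio_edit` is the initializer this PR adds.

```python
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Load an XTTS checkpoint (paths are placeholders).
config = XttsConfig()
config.load_json("/path/to/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/checkpoint/", eval=True)

# The PR adds this initializer for the DVAE used to encode the left/right audio.
model.init_for_audio_edit("/path/to/dvae.pth", "/path/to/mel_stats.pth")

# Conditioning latents from a reference clip, as in regular XTTS inference.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["/path/to/reference.wav"]
)

# Regenerate only the middle span; the left/right audio and text stay fixed.
out = model.audio_edit(
    text_left="the quick brown fox ",
    text_edit="leaps gracefully ",
    text_right="over the lazy dog.",
    audio_left_path="/path/to/left.wav",
    audio_right_path="/path/to/right.wav",
    language="en",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
    generate_times=5,  # more candidates, better odds of a low-L1 match
)

# Assumes the returned dict exposes the waveform under "wav", like inference().
torchaudio.save("edited.wav", out["wav"].unsqueeze(0), 24000)
```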
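The candidate-selection step in the loop scores each generation by how well the latents it induces for the known right context match the latents computed from the ground-truth right audio, replacing the old softmax-gather logit score. A self-contained toy sketch of that heuristic (random tensors stand in for GPT latents; `gen_len` and all dimensions are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gen_len = 60                               # length of the edited span, in codes
right_latents = torch.randn(1, 40, 1024)   # latents of the ground-truth right audio

best_loss, best_gpt_latents = float("inf"), None
for _ in range(5):  # generate_times
    # Latents for [edited span + right context], as returned by the GPT pass.
    right_edit_latents = torch.randn(1, gen_len + 40, 1024)
    right_fake_latents = right_edit_latents[:, gen_len:]  # right-context part only
    l1 = F.l1_loss(right_fake_latents, right_latents)
    if l1 < best_loss:
        best_loss = l1
        best_gpt_latents = right_edit_latents[:, :gen_len]  # keep the edited span

print(f"best candidate L1: {best_loss:.4f}")
```

The winning edited-span latents are then concatenated with the left and right latents and decoded through HiFi-GAN, which is why the final `hifigan_decoder` call takes `[left_latents, best_gpt_latents, right_latents]`.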