UPDATE use l1_loss to choose best result
aaron-lii committed Jan 5, 2024
1 parent 3559581 commit 3f36c4c
Showing 1 changed file with 89 additions and 59 deletions.
TTS/tts/models/xtts.py
@@ -726,27 +726,41 @@ def init_for_audio_edit(self, dvae_checkpoint, mel_norm_file):

     @torch.inference_mode()
     def audio_edit(
-        self,
-        text_left,
-        text_edit,
-        text_right,
-        audio_left_path,
-        audio_right_path,
-        language,
-        gpt_cond_latent,
-        speaker_embedding,
-        generate_times=5,
-        # GPT inference
-        temperature=0.75,
-        length_penalty=1.0,
-        repetition_penalty=10.0,
-        top_k=50,
-        top_p=0.85,
-        do_sample=True,
-        num_beams=1,
-        speed=1.0,
-        enable_text_splitting=False,
-        **hf_generate_kwargs,):
+        self,
+        text_left,
+        text_edit,
+        text_right,
+        audio_left_path,
+        audio_right_path,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        generate_times=5,
+        # GPT inference
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        num_beams=1,
+        speed=1.0,
+        enable_text_splitting=False,
+        **hf_generate_kwargs,
+    ):
"""
This function produces an audio clip of the given text_edit between audio_left and audio_rigth.
Args:
text_left: (str) Text of the left audio.
text_edit: (str) Text of the audio you want to edit.
text_right: (str) Text of the right audio.
audio_left_path: (str) Path of the left auido file.
audio_right_path: (str) Path of the right auido file.
generate_times: (int) The number of times the audio is generated.And choose the best audio.
Returns:
Generated audio clip(s) as a torch tensor.
"""

         load_sr: int = 22050
         audio_left = load_audio(audio_left_path, load_sr)
@@ -766,27 +780,47 @@ def audio_edit(

         wavs = []
         with torch.no_grad():
-            sent = text_left + text_edit
-            sent = sent.strip().lower()
-            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
-
-            sent_left = text_left.strip().lower()
-            text_left_tokens = torch.IntTensor(self.tokenizer.encode(sent_left, lang=language)).unsqueeze(0).to(self.device)
-
-            sent_right = text_edit + text_right
-            sent_right = sent_right.strip().lower()
-            text_right_tokens = torch.IntTensor(self.tokenizer.encode(sent_right, lang=language)).unsqueeze(0).to(self.device)
+            text_left = text_left.strip().lower()
+            text_right = text_right.strip().lower()
+            text_edit = text_edit.strip().lower()
+
+            sent_left_edit = text_left + text_edit
+            sent_right_edit = text_edit + text_right
+
+            text_left_edit_tokens = torch.IntTensor(
+                self.tokenizer.encode(sent_left_edit, lang=language)).unsqueeze(0).to(self.device)
+            text_right_edit_tokens = torch.IntTensor(
+                self.tokenizer.encode(sent_right_edit, lang=language)).unsqueeze(0).to(self.device)
+            text_left_tokens = torch.IntTensor(
+                self.tokenizer.encode(text_left, lang=language)).unsqueeze(0).to(self.device)
+            text_right_tokens = torch.IntTensor(
+                self.tokenizer.encode(text_right, lang=language)).unsqueeze(0).to(self.device)
 
             assert (
-                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+                text_left_edit_tokens.shape[-1] < self.args.gpt_max_text_tokens
             ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
 
-            gpt_inputs = self.gpt.compute_embeddings(gpt_cond_latent, text_tokens)
+            # calculate right audio latents
+            expected_right_output_len = torch.tensor(
+                [codes_right.shape[-1] * self.gpt.code_stride_len], device=text_right_tokens.device
+            )
+            text_right_len = torch.tensor([text_right_tokens.shape[-1]], device=self.device)
+            right_latents = self.gpt(
+                text_right_tokens,
+                text_right_len,
+                codes_right,
+                expected_right_output_len,
+                cond_latents=gpt_cond_latent,
+                return_attentions=False,
+                return_latent=True,
+            )
+
+            gpt_inputs = self.gpt.compute_embeddings(gpt_cond_latent, text_left_edit_tokens)
             # add codes_left to inputs
             gpt_inputs = torch.cat([gpt_inputs, codes_left], dim=1)
 
             best_gpt_latents = None
-            best_logits = 0
+            best_loss = float("inf")
             for _ in range(generate_times):
                 gen = self.gpt.gpt_inference.generate(
                     gpt_inputs,
@@ -809,38 +843,34 @@

                 gen_len = gpt_codes.shape[-1]
 
-                gpt_right_codes = torch.cat([gpt_codes, codes_right], dim=1)
-                expected_output_len = torch.tensor(
-                    [gpt_right_codes.shape[-1] * self.gpt.code_stride_len], device=text_right_tokens.device
+                gpt_right_edit_codes = torch.cat([gpt_codes, codes_right], dim=1)
+                expected_right_edit_output_len = torch.tensor(
+                    [gpt_right_edit_codes.shape[-1] * self.gpt.code_stride_len], device=text_right_edit_tokens.device
                 )
-                text_len = torch.tensor([text_right_tokens.shape[-1]], device=self.device)
+                text_right_edit_len = torch.tensor([text_right_edit_tokens.shape[-1]], device=self.device)
 
-                gpt_latents = self.gpt(
-                    text_right_tokens,
-                    text_len,
-                    gpt_right_codes,
-                    expected_output_len,
+                right_edit_latents = self.gpt(
+                    text_right_edit_tokens,
+                    text_right_edit_len,
+                    gpt_right_edit_codes,
+                    expected_right_edit_output_len,
                     cond_latents=gpt_cond_latent,
                     return_attentions=False,
                     return_latent=True,
                 )
+                right_fake_latents = right_edit_latents[:, gen_len:]
 
-                right_logits = torch.gather(F.softmax(gpt_latents[:, gen_len:], dim=-1),
-                                            -1,
-                                            codes_right.unsqueeze(0).transpose(-1, -2))
-                sum_logits = torch.sum(right_logits, dim=-2)
-                sum_logits = float(sum_logits[0,0])
-
-                if sum_logits > best_logits:
-                    best_logits = sum_logits
-                    best_gpt_latents = gpt_latents.detach()
+                # try to choose the best result
+                l1 = F.l1_loss(right_fake_latents, right_latents)
+                if l1 < best_loss:
+                    best_loss = l1
+                    best_gpt_latents = right_edit_latents[:, : gen_len].detach()
 
             expected_left_len = torch.tensor(
                 [codes_left.shape[-1] * self.gpt.code_stride_len], device=text_left_tokens.device
             )
             text_left_len = torch.tensor([text_left_tokens.shape[-1]], device=self.device)
-            gpt_left_latents = self.gpt(
+            left_latents = self.gpt(
                 text_left_tokens,
                 text_left_len,
                 codes_left,
@@ -850,15 +880,15 @@
                 return_latent=True,
             )
 
-            # if length_scale != 1.0:
-            #     best_gpt_latents = F.interpolate(
-            #         best_gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-            #     ).transpose(1, 2)
+            if length_scale != 1.0:
+                best_gpt_latents = F.interpolate(
+                    best_gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                ).transpose(1, 2)
 
             # wavs.append(load_audio(audio_left_path, self.args.output_sample_rate).cpu().squeeze())
             # wavs.append(self.hifigan_decoder(best_gpt_latents, g=speaker_embedding).cpu().squeeze())
             # wavs.append(load_audio(audio_right_path, self.args.output_sample_rate).cpu().squeeze())
-            wavs.append(self.hifigan_decoder(torch.cat([gpt_left_latents, best_gpt_latents], dim=1),
+            wavs.append(self.hifigan_decoder(torch.cat([left_latents, best_gpt_latents, right_latents], dim=1),
                                              g=speaker_embedding).cpu().squeeze())
 
         return {

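In essence, the commit replaces a softmax-logit score with an L1 criterion for best-of-N selection: latents are computed once from the real right audio, each of the generate_times candidates is scored by the L1 distance between the tail of its latents (the re-encoded right-audio region) and that reference, and only the edited region of the winner is kept. The sketch below is a hypothetical distillation of that loop; the helper pick_best_edit_latents does not exist in the repository, and the tensor shapes are assumptions.

import torch.nn.functional as F

def pick_best_edit_latents(candidate_latents, right_reference, gen_len):
    # Hypothetical helper, not part of TTS/tts/models/xtts.py.
    # candidate_latents: list of [1, gen_len + right_len, dim] tensors, one
    # per generated sample; right_reference: [1, right_len, dim] latents
    # computed from the ground-truth right audio.
    best, best_loss = None, float("inf")
    for latents in candidate_latents:
        # A lower L1 distance between the candidate's re-encoded right-audio
        # tail and the reference means the generated edit joins the right
        # context more cleanly.
        l1 = F.l1_loss(latents[:, gen_len:], right_reference)
        if l1 < best_loss:
            best_loss = l1
            # Keep only the edited region; the left and right latents are
            # re-computed from the real audio before HiFi-GAN decoding.
            best = latents[:, :gen_len].detach()
    return best

Compared with the old criterion, which summed the softmax probabilities the GPT assigned to the ground-truth right-audio codes, the L1 distance is measured directly in the latent space that the HiFi-GAN decoder consumes.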
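For completeness, a hypothetical call sketch, assuming a loaded Xtts model on which init_for_audio_edit has already been run; every text and path below is a placeholder:

# Hypothetical usage; all texts and paths are placeholders.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["speaker.wav"])
result = model.audio_edit(
    text_left="the weather today is",
    text_edit="absolutely wonderful,",
    text_right="so we are going for a walk.",
    audio_left_path="left.wav",
    audio_right_path="right.wav",
    language="en",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
    generate_times=5,
)
# audio_edit returns a dict (its exact keys are truncated in the diff above);
# per the docstring it carries the generated audio clip as a torch tensor.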