
Commit

refactor(audio.processor): remove duplicate quantization methods
eginhard committed Nov 15, 2023
1 parent ddbaecd commit 8f1db75
Showing 6 changed files with 17 additions and 45 deletions.
3 changes: 2 additions & 1 deletion TTS/bin/extract_tts_spectrograms.py
@@ -15,6 +15,7 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import count_parameters

use_cuda = torch.cuda.is_available()
@@ -197,7 +198,7 @@ def extract_spectrograms(

# quantize and save wav
if quantize_bits > 0:
wavq = ap.quantize(wav, quantize_bits)
wavq = quantize(wav, quantize_bits)
np.save(wavq_path, wavq)

# save TTS mel
40 changes: 0 additions & 40 deletions TTS/utils/audio/processor.py
@@ -631,43 +631,3 @@ def get_duration(self, filename: str) -> float:
filename (str): Path to the wav file.
"""
return librosa.get_duration(filename=filename)

@staticmethod
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
mu = 2**qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
# Quantize signal to the specified number of levels.
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(
signal,
)

@staticmethod
def mulaw_decode(wav, qc):
"""Recovers waveform from quantized values."""
mu = 2**qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x

@staticmethod
def encode_16bits(x):
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)

@staticmethod
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
"""Quantize a waveform to a given number of bits.
Args:
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
bits (int): Number of quantization bits.
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2**bits - 1) / 2

@staticmethod
def dequantize(x, bits):
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2**bits - 1) - 1
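
For orientation, here is a minimal sketch of the linear-quantization path after this refactor, assuming numpy_transforms.quantize behaves like the duplicate removed above (the keyword names are the ones used at the updated call sites). The inverse step simply replays the removed AudioProcessor.dequantize formula inline rather than importing anything additional.

import numpy as np

from TTS.utils.audio.numpy_transforms import quantize

quantize_bits = 10
wav = np.sin(np.linspace(0, 2 * np.pi, 16000))  # toy waveform in [-1, 1]

# Old call site: ap.quantize(wav, quantize_bits) -> new module-level function.
wavq = quantize(x=wav, quantize_bits=quantize_bits)  # float levels in [0, 2**quantize_bits - 1]

# Inverse, replayed from the removed AudioProcessor.dequantize formula.
wav_back = 2 * wavq / (2**quantize_bits - 1) - 1

# Exact inverse up to floating-point error when the two implementations match.
assert np.allclose(wav, wav_back, atol=1e-6)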
7 changes: 6 additions & 1 deletion TTS/vocoder/datasets/preprocess.py
@@ -7,6 +7,7 @@
from tqdm import tqdm

from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
@@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
quant = (
mulaw_encode(wav=y, mulaw_qc=config.mode)
if config.model_args.mulaw
else quantize(x=y, quantize_bits=config.mode)
)
np.save(quant_path, quant)


6 changes: 5 additions & 1 deletion TTS/vocoder/datasets/wavernn_dataset.py
@@ -2,6 +2,8 @@
import torch
from torch.utils.data import Dataset

from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


class WaveRNNDataset(Dataset):
"""
@@ -66,7 +68,9 @@ def load_item(self, index):
x_input = audio
elif isinstance(self.mode, int):
x_input = (
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
mulaw_encode(wav=audio, mulaw_qc=self.mode)
if self.mulaw
else quantize(x=audio, quantize_bits=self.mode)
)
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
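
Both updated call sites (preprocess_wav_files above and WaveRNNDataset.load_item here) branch the same way on the dataset mode; a condensed, hypothetical helper that mirrors that decision is sketched below. The function name is illustrative and not part of the codebase.

from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


def encode_wavernn_target(audio, mode, use_mulaw):
    # Illustrative helper, not in the repo: integer modes produce discrete targets,
    # either mu-law companded labels or linearly quantized levels.
    if isinstance(mode, int):
        if use_mulaw:
            return mulaw_encode(wav=audio, mulaw_qc=mode)
        return quantize(x=audio, quantize_bits=mode)
    # Non-integer modes (continuous output heads) keep the raw waveform.
    return audio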
3 changes: 2 additions & 1 deletion TTS/vocoder/models/wavernn.py
@@ -13,6 +13,7 @@

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
@@ -399,7 +400,7 @@ def inference(self, mels, batched=None, target=None, overlap=None):
output = output[0]

if self.args.mulaw and isinstance(self.args.mode, int):
output = AudioProcessor.mulaw_decode(output, self.args.mode)
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)

# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
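
At inference, the hunk above expands the network's normalized output back to a waveform with mulaw_decode. Below is a small standalone sketch of that step, using mulaw_encode to fabricate stand-in labels and a rescaling that mimics what the sampling loop produces; the bit depth and error tolerance are illustrative.

import numpy as np

from TTS.utils.audio.numpy_transforms import mulaw_decode, mulaw_encode

mulaw_qc = 10  # illustrative bit depth for the companding
mu = 2**mulaw_qc - 1

wav = np.sin(np.linspace(0, 2 * np.pi, 16000))     # toy waveform in [-1, 1]
labels = mulaw_encode(wav=wav, mulaw_qc=mulaw_qc)  # integer levels in [0, mu]
normalized = 2 * labels / mu - 1                   # scaled to [-1, 1], as the model emits
recovered = mulaw_decode(wav=normalized, mulaw_qc=mulaw_qc)

# recovered tracks wav up to mu-law quantization error (coarser for large amplitudes).
assert np.abs(recovered - wav).max() < 0.05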
3 changes: 2 additions & 1 deletion notebooks/ExtractTTSpectrogram.ipynb
@@ -34,6 +34,7 @@
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.audio.numpy_transforms import quantize\n",
"\n",
"%matplotlib inline\n",
"\n",
@@ -190,7 +191,7 @@
"\n",
" # quantize and save wav\n",
" if QUANTIZE_BITS > 0:\n",
" wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
" wavq = quantize(wav, QUANTIZE_BITS)\n",
" np.save(wavq_path, wavq)\n",
"\n",
" # save TTS mel\n",
