From 5dd1e59007b21f8c2bd05936d5f69393c2693ece Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 31 Jul 2023 11:37:35 +0200
Subject: [PATCH] Add kwargs to ignore extra arguments w/o error (#2822)

---
 TTS/tts/layers/bark/inference_funcs.py | 55 ++++++++++++++++++++++++--
 TTS/tts/models/bark.py                 |  2 +-
 2 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py
index 3a6875ef39..d7f3f79345 100644
--- a/TTS/tts/layers/bark/inference_funcs.py
+++ b/TTS/tts/layers/bark/inference_funcs.py
@@ -162,8 +162,28 @@ def generate_text_semantic(
     allow_early_stop=True,
     base=None,
     use_kv_caching=True,
+    **kwargs,  # pylint: disable=unused-argument
 ):
-    """Generate semantic tokens from text."""
+    """Generate semantic tokens from text.
+
+    Args:
+        text (str): The text to generate semantic tokens from.
+        model (BarkModel): The BarkModel to use for generating the semantic tokens.
+        history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation.
+        temp (float): The temperature to use for the generation.
+        top_k (int): The number of top tokens to consider for the generation.
+        top_p (float): The cumulative probability to consider for the generation.
+        silent (bool): Whether to silence the tqdm progress bar.
+        min_eos_p (float): The minimum probability to consider for the end-of-sentence token.
+        max_gen_duration_s (float): The maximum duration in seconds to generate for.
+        allow_early_stop (bool): Whether to allow the generation to stop early.
+        base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation.
+        use_kv_caching (bool): Whether to use key-value caching for the generation.
+        **kwargs: Additional keyword arguments. They are ignored.
+
+    Returns:
+        np.ndarray: The generated semantic tokens.
+    """
     print(f"history_prompt in gen: {history_prompt}")
     assert isinstance(text, str)
     text = _normalize_whitespace(text)
@@ -298,7 +318,24 @@ def generate_coarse(
     base=None,
     use_kv_caching=True,
 ):
-    """Generate coarse audio codes from semantic tokens."""
+    """Generate coarse audio codes from semantic tokens.
+
+    Args:
+        x_semantic (np.ndarray): The semantic tokens to generate coarse audio codes from.
+        model (BarkModel): The BarkModel to use for generating the coarse audio codes.
+        history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation.
+        temp (float): The temperature to use for the generation.
+        top_k (int): The number of top tokens to consider for the generation.
+        top_p (float): The cumulative probability to consider for the generation.
+        silent (bool): Whether to silence the tqdm progress bar.
+        max_coarse_history (int): The maximum number of coarse audio codes to use as history.
+        sliding_window_len (int): The length of the sliding window to use for the generation.
+        base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation.
+        use_kv_caching (bool): Whether to use key-value caching for the generation.
+
+    Returns:
+        np.ndarray: The generated coarse audio codes.
+    """
     assert (
         isinstance(x_semantic, np.ndarray)
         and len(x_semantic.shape) == 1
@@ -453,7 +490,19 @@ def generate_fine(
     silent=True,
     base=None,
 ):
-    """Generate full audio codes from coarse audio codes."""
+    """Generate full audio codes from coarse audio codes.
+
+    Args:
+        x_coarse_gen (np.ndarray): The coarse audio codes to generate full audio codes from.
+        model (BarkModel): The BarkModel to use for generating the full audio codes.
+        history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation.
+        temp (float): The temperature to use for the generation.
+        silent (bool): Whether to silence the tqdm progress bar.
+        base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation.
+
+    Returns:
+        np.ndarray: The generated full audio codes.
+    """
     assert (
         isinstance(x_coarse_gen, np.ndarray)
         and len(x_coarse_gen.shape) == 2
diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py
index ee3b820637..f198c3d58a 100644
--- a/TTS/tts/models/bark.py
+++ b/TTS/tts/models/bark.py
@@ -206,7 +206,7 @@ def synthesize(
             speaker_wav (str): Path to the speaker audio file for cloning a new voice. It is cloned and saved in
                 `voice_dirs` with the name `speaker_id`. Defaults to None.
             voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None.
-            **kwargs: Inference settings. See `inference()`.
+            **kwargs: Model-specific inference settings used by `generate_audio()` and `TTS.tts.layers.bark.inference_funcs.generate_text_semantic()`.

         Returns:
             A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference,
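For context, a minimal sketch of the pattern this patch applies (the function and
setting names below are illustrative, not from the TTS codebase): a trailing
`**kwargs` swallows unknown keyword arguments, so a caller can pass a superset of
settings without raising a TypeError for keys the function does not use.

    # Hypothetical example, not part of the patch: extra keyword arguments
    # are accepted and deliberately ignored, mirroring the change made to
    # generate_text_semantic() above.
    def generate(text, temp=0.7, use_kv_caching=True, **kwargs):  # pylint: disable=unused-argument
        return f"text={text!r}, temp={temp}, kv={use_kv_caching}"

    # "top_fine_temp" is unknown to generate(); with **kwargs in the
    # signature it is silently dropped instead of raising TypeError.
    settings = {"temp": 0.9, "use_kv_caching": False, "top_fine_temp": 0.5}
    print(generate("hello", **settings))

This is why `synthesize()` can forward one shared settings dict to every stage of
the Bark pipeline even though each stage accepts a different subset of keys.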