Skip to content

Commit

Permalink
Add kwargs to ignore extra arguments w/o error (coqui-ai#2822)
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol authored and Tindell committed Aug 14, 2023
1 parent e3aab83 commit 5dd1e59
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
55 changes: 52 additions & 3 deletions TTS/tts/layers/bark/inference_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,28 @@ def generate_text_semantic(
allow_early_stop=True,
base=None,
use_kv_caching=True,
**kwargs, # pylint: disable=unused-argument
):
"""Generate semantic tokens from text."""
"""Generate semantic tokens from text.
Args:
text (str): The text to generate semantic tokens from.
model (BarkModel): The BarkModel to use for generating the semantic tokens.
history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation.
temp (float): The temperature to use for the generation.
top_k (int): The number of top tokens to consider for the generation.
top_p (float): The cumulative probability to consider for the generation.
silent (bool): Whether to silence the tqdm progress bar.
min_eos_p (float): The minimum probability to consider for the end of sentence token.
max_gen_duration_s (float): The maximum duration in seconds to generate for.
allow_early_stop (bool): Whether to allow the generation to stop early.
base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation.
use_kv_caching (bool): Whether to use key-value caching for the generation.
**kwargs: Additional keyword arguments. They are ignored.
Returns:
np.ndarray: The generated semantic tokens.
"""
print(f"history_prompt in gen: {history_prompt}")
assert isinstance(text, str)
text = _normalize_whitespace(text)
Expand Down Expand Up @@ -298,7 +318,24 @@ def generate_coarse(
base=None,
use_kv_caching=True,
):
"""Generate coarse audio codes from semantic tokens."""
"""Generate coarse audio codes from semantic tokens.
Args:
x_semantic (np.ndarray): The semantic tokens to generate coarse audio codes from.
model (BarkModel): The BarkModel to use for generating the coarse audio codes.
history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation.
temp (float): The temperature to use for the generation.
top_k (int): The number of top tokens to consider for the generation.
top_p (float): The cumulative probability to consider for the generation.
silent (bool): Whether to silence the tqdm progress bar.
max_coarse_history (int): The maximum number of coarse audio codes to use as history.
sliding_window_len (int): The length of the sliding window to use for the generation.
base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation.
use_kv_caching (bool): Whether to use key-value caching for the generation.
Returns:
np.ndarray: The generated coarse audio codes.
"""
assert (
isinstance(x_semantic, np.ndarray)
and len(x_semantic.shape) == 1
Expand Down Expand Up @@ -453,7 +490,19 @@ def generate_fine(
silent=True,
base=None,
):
"""Generate full audio codes from coarse audio codes."""
"""Generate full audio codes from coarse audio codes.
Args:
x_coarse_gen (np.ndarray): The coarse audio codes to generate full audio codes from.
model (BarkModel): The BarkModel to use for generating the full audio codes.
history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation.
temp (float): The temperature to use for the generation.
silent (bool): Whether to silence the tqdm progress bar.
base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation.
Returns:
np.ndarray: The generated full audio codes.
"""
assert (
isinstance(x_coarse_gen, np.ndarray)
and len(x_coarse_gen.shape) == 2
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def synthesize(
speaker_wav (str): Path to the speaker audio file for cloning a new voice. It is cloned and saved in
`voice_dirs` with the name `speaker_id`. Defaults to None.
voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None.
**kwargs: Inference settings. See `inference()`.
**kwargs: Model specific inference settings used by `generate_audio()` and `TTS.tts.layers.bark.inference_funcs.generate_text_semantic().
Returns:
A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference,
Expand Down

0 comments on commit 5dd1e59

Please sign in to comment.