Commit: Make it work on mps

gravityrail committed Jun 15, 2024
1 parent f5b81c9 commit bd2f992
Showing 2 changed files with 41 additions and 10 deletions.
44 changes: 37 additions & 7 deletions TTS/tts/layers/xtts/stream_generator.py
@@ -182,14 +182,44 @@ def generate(
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            pad_token_tensor = torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) if generation_config.pad_token_id is not None else None
-            eos_token_tensor = torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) if generation_config.eos_token_id is not None else None
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                pad_token_tensor,
-                eos_token_tensor,
+        if (
+            model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask
+        ):
+            pad_token_tensor = (
+                torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device)
+                if generation_config.pad_token_id is not None
+                else None
+            )
+            eos_token_tensor = (
+                torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
+                if generation_config.eos_token_id is not None
+                else None
+            )
+
+            # Hack: build the attention mask ourselves on MPS devices, since transformers bails out there even though PyTorch now supports torch.isin on MPS.
+            # For this to work, run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel (see the usage sketch below this hunk).
+            if inputs_tensor.device.type == "mps":
+                default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
+
+                is_pad_token_in_inputs = (pad_token_tensor is not None) and (
+                    torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
+                )
+                is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
+                    torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+                )
+                can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+                attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
+
+                model_kwargs["attention_mask"] = (
+                    attention_mask_from_padding * can_infer_attention_mask
+                    + default_attention_mask * ~can_infer_attention_mask
+                )
+            else:
+                model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                    inputs_tensor,
+                    pad_token_tensor,
+                    eos_token_tensor,
             )
 
         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:
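
The MPS branch above mirrors the mask-inference logic that transformers applies in `_prepare_attention_mask_for_generation`: boolean flags are combined with integer arithmetic so that the padding-derived mask is used only when a pad token actually appears in the inputs and differs from the EOS token; otherwise an all-ones mask is used. A minimal standalone sketch of that arithmetic select, with toy token IDs chosen purely for illustration:

import torch

# Toy batch: token ID 5 plays the pad token, 4 the EOS token.
inputs = torch.tensor([[5, 5, 1, 2, 3]])
pad_token = torch.tensor([5])
eos_token = torch.tensor([4])

default_mask = torch.ones(inputs.shape[:2], dtype=torch.long)

# The mask is inferable only if pad tokens occur and pad != eos.
is_pad_in_inputs = torch.isin(inputs, pad_token).any()
pad_is_not_eos = ~torch.isin(eos_token, pad_token).any()
can_infer = is_pad_in_inputs * pad_is_not_eos  # boolean AND via multiply

# Arithmetic select: padding-derived mask if inferable, else all ones.
mask_from_padding = inputs.ne(pad_token).long()
mask = mask_from_padding * can_infer + default_mask * ~can_infer
print(mask)  # tensor([[0, 0, 1, 1, 1]])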
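Per the comment in the hunk, running XTTS on Apple Silicon also requires the MPS fallback environment variable and moving the model to the `mps` device. A minimal usage sketch via Coqui's high-level API; the model name and file paths are illustrative, and the variable must be set before torch initializes:

import os

# Must be set before torch touches MPS, per the commit comment.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from TTS.api import TTS

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Move the whole XTTS model onto the MPS device.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

tts.tts_to_file(
    text="Testing XTTS on Apple Silicon.",
    speaker_wav="speaker.wav",  # illustrative reference clip
    language="en",
    file_path="out.wav",
)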
7 changes: 4 additions & 3 deletions requirements.txt
@@ -3,8 +3,9 @@ numpy==1.22.0;python_version<="3.10"
 numpy>=1.24.3;python_version>"3.10"
 cython>=0.29.30
 scipy>=1.11.2
-torch>=2.1
-torchaudio
+torch==2.3.1
+torchaudio==2.3.1
+torchvision==0.18.1
 soundfile>=0.12.0
 librosa>=0.10.0
 scikit-learn>=1.3.0
@@ -48,7 +49,7 @@ bnnumerizer
 bnunicodenormalizer
 #deps for tortoise
 einops>=0.6.0
-transformers>=4.33.0
+transformers>=4.41.2
 #deps for bark
 encodec>=0.1.1
 # deps for XTTS
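The pins above move torch, torchaudio, and torchvision to the 2.3.1 family, the line on which (per the comment in the first hunk) torch.isin works on MPS. A quick environment sanity check, nothing repo-specific:

import torch

print(torch.__version__)                  # expect 2.3.1 with the pins above
print(torch.backends.mps.is_built())      # True if this torch build includes MPS support
print(torch.backends.mps.is_available())  # True on Apple Silicon with a supported macOS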
