Skip to content

Commit

Permalink
Ensures that only GPT model is in training mode during XTTS GPT train…
Browse files Browse the repository at this point in the history
…ing (#3241)

* Ensures that only GPT model is in training mode during training

* Fix parallel wavegan unit test
  • Loading branch information
Edresson committed Nov 17, 2023
1 parent 14579a4 commit 11283fc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
7 changes: 4 additions & 3 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,10 @@ def eval_step(self, batch, criterion):
batch["cond_idxs"] = None
return self.train_step(batch, criterion)

def on_epoch_start(self, trainer): # pylint: disable=W0613
# guarante that dvae will be in eval mode after .train() on evaluation end
self.dvae = self.dvae.eval()
def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()

def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload
Expand Down
1 change: 1 addition & 0 deletions TTS/vocoder/configs/parallel_wavegan_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: int = 200000
target_loss: str = "loss_1"

# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pandas>=1.4,<2.0
# deps for training
matplotlib>=3.7.0
# coqui stack
trainer
trainer>=0.0.32
# config management
coqpit>=0.0.16
# chinese g2p deps
Expand Down

0 comments on commit 11283fc

Please sign in to comment.