diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 29e087ef06..a32ad00f56 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -125,7 +125,7 @@ def evaluation(model, criterion, data_loader, global_step): def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step): model.train() - best_loss = float("inf") + best_loss = {"train_loss": None, "eval_loss": float("inf")} avg_loader_time = 0 end_time = time.time() for epoch in range(c.epochs):