Skip to content

Commit

Permalink
Merge pull request #140 from RaymondNie/master
Browse files Browse the repository at this point in the history
removed global var fs in train.py and replaced with hparams.sample_rate
  • Loading branch information
r9y9 authored Feb 26, 2019
2 parents 7c2b5f2 + 8dd39a3 commit 6fb72bf
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@
from warnings import warn
from hparams import hparams, hparams_debug_string

-fs = hparams.sample_rate
-
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
Expand Down Expand Up @@ -424,7 +422,7 @@ def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeake

try:
writer.add_audio("(Eval) Predicted audio signal {}_{}".format(idx, speaker_str),
-                             signal, global_step, sample_rate=fs)
+                             signal, global_step, sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
Expand Down Expand Up @@ -481,7 +479,7 @@ def save_states(global_step, writer, mel_outputs, linear_outputs, attn, mel, y,
path = join(checkpoint_dir, "step{:09d}_predicted.wav".format(
global_step))
try:
-        writer.add_audio("Predicted audio signal", signal, global_step, sample_rate=fs)
+        writer.add_audio("Predicted audio signal", signal, global_step, sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
Expand Down Expand Up @@ -681,7 +679,7 @@ def train(device, model, data_loader, optimizer, writer,

# linear:
if train_postnet:
-                n_priority_freq = int(hparams.priority_freq / (fs * 0.5) * linear_dim)
+                n_priority_freq = int(hparams.priority_freq / (hparams.sample_rate * 0.5) * linear_dim)
linear_l1_loss, linear_binary_div = spec_loss(
linear_outputs[:, :-r, :], y[:, r:, :], target_mask,
priority_bin=n_priority_freq,
Expand Down

0 comments on commit 6fb72bf

Please sign in to comment.