distill: add keep-best-steps option
With this option only the parameter files for the N best steps are
retained during distillation.

(cherry picked from commit 68eb99c)
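
[Editor's note] For example, with --keep-best-steps 3, only the parameter files for the three most recent best evaluation steps (saved as distill-step-<step>) remain on disk; each newer best step evicts the oldest retained file.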
danieldk committed Feb 5, 2021
1 parent 89f603a commit 2059314
Showing 1 changed file with 51 additions and 9 deletions.

syntaxdot-cli/src/subcommands/distill.rs
@@ -1,9 +1,10 @@
 use std::cell::RefCell;
 use std::collections::btree_map::{BTreeMap, Entry};
-use std::fs::File;
+use std::collections::VecDeque;
+use std::fs::{self, File};
 use std::io::{BufReader, Read, Seek};
 
-use anyhow::{Context, Result};
+use anyhow::{bail, Context, Result};
 use clap::{App, Arg, ArgMatches};
 use indicatif::{ProgressBar, ProgressStyle};
 use itertools::Itertools;
@@ -41,6 +42,7 @@ const GPU: &str = "GPU";
 const HARD_LOSS: &str = "HARD_LOSS";
 const INITIAL_LR_CLASSIFIER: &str = "INITIAL_LR_CLASSIFIER";
 const INITIAL_LR_ENCODER: &str = "INITIAL_LR_ENCODER";
+const KEEP_BEST_STEPS: &str = "KEEP_BEST_STEPS";
 const LR_DECAY_RATE: &str = "LR_DECAY_RATE";
 const LR_DECAY_STEPS: &str = "LR_DECAY_STEPS";
 const MAX_LEN: &str = "MAX_LEN";
@@ -64,6 +66,7 @@ pub struct DistillApp {
     device: Device,
     eval_steps: usize,
     hard_loss: bool,
+    keep_best_steps: Option<usize>,
     max_len: Option<SequenceLength>,
     mixed_precision: bool,
     lr_schedules: RefCell<LearningRateSchedules>,
@@ -148,6 +151,8 @@ impl DistillApp {
 
         let mut global_step = 0;
 
+        let mut best_step_paths = self.keep_best_steps.map(VecDeque::with_capacity);
+
         let n_steps = self
             .train_duration
             .to_steps(&teacher_train_file, self.batch_size)
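
[Editor's note] A minimal sketch of what this initialization does (standalone; the variable names mirror the diff, but nothing here is part of the commit): no queue is ever allocated unless --keep-best-steps was given.

use std::collections::VecDeque;

fn main() {
    // With --keep-best-steps 3: room for three step paths up front.
    let keep_best_steps: Option<usize> = Some(3);
    let best_step_paths: Option<VecDeque<String>> =
        keep_best_steps.map(VecDeque::with_capacity);
    assert!(best_step_paths.is_some());

    // Without the flag, the Option stays None and cleanup is a no-op.
    let disabled: Option<usize> = None;
    assert!(disabled.map(VecDeque::<String>::with_capacity).is_none());
}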
@@ -208,13 +213,14 @@
                     best_step = global_step;
                     best_acc = acc;
 
-                    student
-                        .vs
-                        .save(format!("distill-step-{}", global_step))
-                        .context(format!(
-                            "Cannot save variable store for step {}",
-                            global_step
-                        ))?;
+                    let step_path = format!("distill-step-{}", global_step);
+
+                    student.vs.save(&step_path).context(format!(
+                        "Cannot save variable store for step {}",
+                        global_step
+                    ))?;
+
+                    self.cleanup_old_best_steps(&mut best_step_paths, step_path);
                 }
 
                 let step_status = if best_step == global_step { "🎉" } else { "" };
@@ -233,6 +239,23 @@
         Ok(())
     }
 
+    fn cleanup_old_best_steps(
+        &self,
+        best_step_paths: &mut Option<VecDeque<String>>,
+        step_path: String,
+    ) {
+        if let Some(best_step_paths) = best_step_paths.as_mut() {
+            if best_step_paths.len() == self.keep_best_steps.unwrap() {
+                let cleanup_step = best_step_paths.pop_front().expect("No steps?");
+                if let Err(err) = fs::remove_file(&cleanup_step) {
+                    eprintln!("Cannot remove step parameters {}: {}", cleanup_step, err);
+                }
+            }
+
+            best_step_paths.push_back(step_path);
+        }
+    }
+
     fn student_loss(
         &self,
         teacher: &BertModel,
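
[Editor's note] A self-contained sketch of the retention policy implemented by cleanup_old_best_steps (the helper name retain_best and the println stand-in for fs::remove_file are illustrative assumptions): the queue holds at most N paths, and each new best step evicts the oldest one.

use std::collections::VecDeque;

// Stand-in for cleanup_old_best_steps: file removal is replaced by a
// println so the sketch runs without touching the filesystem.
fn retain_best(best: &mut VecDeque<String>, n: usize, step_path: String) {
    if best.len() == n {
        let evicted = best.pop_front().expect("queue is non-empty");
        println!("would remove {}", evicted);
    }
    best.push_back(step_path);
}

fn main() {
    let mut best = VecDeque::with_capacity(2);
    for step in [100, 250, 600] {
        retain_best(&mut best, 2, format!("distill-step-{}", step));
    }
    // With N = 2, only the two most recent best steps survive.
    let kept: Vec<_> = best.iter().map(String::as_str).collect();
    assert_eq!(kept, ["distill-step-250", "distill-step-600"]);
}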
@@ -660,6 +683,12 @@ impl SyntaxDotApp for DistillApp {
                     .help("Initial encoder learning rate")
                     .default_value("5e-5"),
             )
+            .arg(
+                Arg::with_name(KEEP_BEST_STEPS)
+                    .long("keep-best-steps")
+                    .value_name("N")
+                    .help("Only keep the N best steps"),
+            )
             .arg(
                 Arg::with_name(MIXED_PRECISION)
                     .long("mixed-precision")
@@ -752,6 +781,18 @@
             .parse()
             .context("Cannot parse initial encoder learning rate")?;
         let summary_writer = SummaryOption::parse(matches)?;
+
+        let keep_best_steps = matches
+            .value_of(KEEP_BEST_STEPS)
+            .map(|n| {
+                n.parse()
+                    .context("Cannot parse number of best steps to keep")
+            })
+            .transpose()?;
+        if keep_best_steps == Some(0) {
+            bail!("Refusing to keep zero steps")
+        }
+
         let lr_decay_rate = matches
             .value_of(LR_DECAY_RATE)
             .unwrap()
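
[Editor's note] A minimal sketch of the Option/Result inversion used in the parsing above (standalone, with std's ParseIntError in place of anyhow): map turns Option<&str> into Option<Result<usize, _>>, and transpose flips that into Result<Option<usize>, _> so a single ? propagates the error.

use std::num::ParseIntError;

// Illustrative helper mirroring the optional-flag parsing in the diff.
fn parse_optional(raw: Option<&str>) -> Result<Option<usize>, ParseIntError> {
    raw.map(|n| n.parse()).transpose()
}

fn main() {
    assert_eq!(parse_optional(None), Ok(None)); // flag not given
    assert_eq!(parse_optional(Some("3")), Ok(Some(3))); // --keep-best-steps 3
    assert!(parse_optional(Some("three")).is_err()); // rejected up front
}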
@@ -800,6 +841,7 @@ impl SyntaxDotApp for DistillApp {
             device,
             eval_steps,
             hard_loss,
+            keep_best_steps,
             max_len,
             mixed_precision,
             lr_schedules: RefCell::new(Self::create_lr_schedules(