diff --git a/benches/hyperfine.sh b/benches/hyperfine.sh index a663dc1..7910dc6 100755 --- a/benches/hyperfine.sh +++ b/benches/hyperfine.sh @@ -23,9 +23,9 @@ function speedtest { --style full \ --parameter-list sample_size $sample_sizes \ "$python unweighted --size {sample_size} $tmp_src" \ - "cargo run --release unweighted --size {sample_size} < $tmp_src" \ + "cargo run --release -- --size {sample_size} unweighted < $tmp_src" \ "$python weighted --size {sample_size} $tmp_src <(cut -f 8 < $tmp_src)" \ - "cargo run --release weighted --size {sample_size} $tmp_src <(cut -f 8 < $tmp_src)" + "cargo run --release -- --size {sample_size} weighted $tmp_src <(cut -f 8 < $tmp_src)" } diff --git a/src/main.rs b/src/main.rs index d29b081..1da1ae7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,6 +35,14 @@ use reservoir_sampling::{ struct Cli { #[clap(subcommand)] command: Commands, + + /// Seed for reproducibility + #[clap(long, parse(try_from_str))] + seed: Option, + + /// Sample size + #[clap(short, long, default_value_t=10)] + size: usize, } @@ -44,14 +52,6 @@ enum Commands { #[clap(arg_required_else_help=false, visible_alias="uw")] /// Unweighted resevoir sampling Unweighted { - /// Seed for reproducibility - #[clap(long, parse(try_from_str))] - seed: Option, - - /// Sample size - #[clap(short, long, default_value_t=10)] - size: usize, - /// Population file name. #[clap(name="population's file name")] population_fn: Option, @@ -60,14 +60,6 @@ enum Commands { #[clap(arg_required_else_help=false, visible_alias="w")] /// Weighted reservoir sampling Weighted { - /// Seed for reproducibility - #[clap(long, parse(try_from_str))] - seed: Option, - - /// Sample size - #[clap(short, long, default_value_t=10)] - size: usize, - /// Population file name. #[clap(name="population's file name")] population_fn: String, @@ -121,21 +113,20 @@ fn get_rng(seed: &Option) -> SmallRng fn main() { let args = Cli::parse(); + let mut rng = get_rng(&args.seed); match &args.command { - Commands::Unweighted { size, seed, population_fn } => { - let mut rng = get_rng(seed); + Commands::Unweighted { population_fn } => { let population = get_reader(population_fn); let samples = l( &mut population.lines().map(|v| v.unwrap()), - *size, + args.size, &mut rng); for sample in samples { println!("{}", sample); } } - Commands::Weighted { size, seed, weight_fn, population_fn } => { - let mut rng = get_rng(seed); + Commands::Weighted { weight_fn, population_fn } => { let weights = read_lines(weight_fn) .unwrap() .map(Result::unwrap) @@ -147,7 +138,7 @@ fn main() { let samples = a_exp_j( &mut weighted_samples, - *size, + args.size, &mut rng); for sample in samples { println!("{}", sample);