Skip to content

Commit

Permalink
Merge pull request #12 from LSSTDESC/issue/9/consolidate-e2e-script
Browse files Browse the repository at this point in the history
Remove e2e tests script, and replace it with equivalent unit tests
  • Loading branch information
drewoldag authored Oct 2, 2024
2 parents 423984b + afff765 commit 534533f
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 78 deletions.
32 changes: 0 additions & 32 deletions .github/workflows/e2e-tests.yml

This file was deleted.

8 changes: 4 additions & 4 deletions src/resspect/scripts/fit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,20 @@ def fit_dataset(user_choices):
features_file = user_choices.output
ncores = user_choices.ncores

if user_choices.sim_name in ['SNPCC', 'snpcc']:
if user_choices.sim_name.lower() == 'snpcc':
# fit the entire sample
fit_snpcc(path_to_data_dir=data_dir, features_file=features_file,
number_of_processors=ncores, feature_extractor=user_choices.function)

elif user_choices.sim_name in ['PLAsTiCC', 'PLASTICC', 'plasticc']:
elif user_choices.sim_name.lower() == 'plasticc':
fit_plasticc(path_photo_file=user_choices.photo_file,
path_header_file=user_choices.header_file,
output_file=features_file,
sample=user_choices.sample,
number_of_processors=ncores)

return None

else:
raise ValueError("-s or --simulation not recognized. Options are 'SNPCC' or 'PLAsTiCC'.")

def main():

Expand Down
34 changes: 21 additions & 13 deletions src/resspect/scripts/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['learn_loop', 'run_loop']
__all__ = ['run_loop']

from resspect.learn_loop import learn_loop

Expand Down Expand Up @@ -63,23 +63,31 @@ def run_loop(args):
"""

# set training sample variable
if args.training == 'original':
train = 'original'
elif isinstance(int(args.training), int):
train = int(args.training)
else:
raise ValueError('-t or --training option must be '
'"original" or integer!')

# run active learning loop
learn_loop(nloops=args.nquery, features_method=args.method,
learn_loop(nloops=args.nquery,
features_method=args.method,
classifier=args.classifier,
strategy=args.strategy, path_to_features=args.input,
strategy=args.strategy,
path_to_features=args.input,
output_metrics_file=args.metrics,
output_queried_file=args.queried,
training=train, batch=args.batch)
training=_parse_training(args.training),
batch=args.batch)

def _parse_training(training:str):
"""We don't check that `isinstance(training, str)` because `training` is defined
as a string in the argparse.ArgumentParser.
"""
# set training sample variable
if training.lower() == 'original':
train = 'original'
else:
try:
train = int(training)
except ValueError:
raise ValueError('-t or --training option must be "original" or integer!')

return train

def main():

Expand Down
29 changes: 0 additions & 29 deletions tests/resspect/run-snpcc-e2e.sh

This file was deleted.

25 changes: 25 additions & 0 deletions tests/resspect/test_learn_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,30 @@ def test_can_run_learn_loop(test_des_data_path):
output_queried_file=os.path.join(dir_name,"just_other_name.csv"),
)


def test_can_run_learn_loop_uncsample(test_des_data_path):
"""Test that learn_loop can load data and run.
This instance is distinct from the previous because it uses `UncSample` strategy
and runs for 2 loops instead of 1.
"""
with tempfile.TemporaryDirectory() as dir_name:
# Create the feature files to use for the learning loop.
output_file = os.path.join(dir_name, "output_file.dat")
fit_snpcc(
path_to_data_dir=test_des_data_path,
features_file=output_file
)

learn_loop(
nloops=2,
features_method="bazin",
strategy="UncSampling",
path_to_features=output_file,
output_metrics_file=os.path.join(dir_name,"just_a_name.csv"),
output_queried_file=os.path.join(dir_name,"just_other_name.csv"),
training=10,
)


if __name__ == '__main__':
pytest.main()
20 changes: 20 additions & 0 deletions tests/test_cli_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest


def test_run_loop_arg_check():
"""Test that the logic that parses CLI inputs works as expected. Specifically
that the training sample is correctly parsed as either an integer or the string 'original'.
"""

from resspect.scripts.run_loop import _parse_training

assert _parse_training('original') == 'original'
assert _parse_training('OrigINAL') == 'original'
assert _parse_training('10') == 10
with pytest.raises(ValueError):
_parse_training('not_a_number')
with pytest.raises(ValueError):
_parse_training('1.0')
with pytest.raises(ValueError):
_parse_training('')
assert _parse_training('010') == 10

0 comments on commit 534533f

Please sign in to comment.