From f6695ca2f8d5c27d3d311821874a29a65205cccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 21 Mar 2017 16:38:31 +0100 Subject: [PATCH] Fitter: Use learner defaults for params set to None --- Orange/modelling/base.py | 4 +++- Orange/modelling/linear.py | 8 ++++---- Orange/tests/test_fitter.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/Orange/modelling/base.py b/Orange/modelling/base.py index 735432fb436..d2faed26137 100644 --- a/Orange/modelling/base.py +++ b/Orange/modelling/base.py @@ -83,7 +83,9 @@ def __kwargs(self, problem_type): learner_kwargs = set( self.__fits__[problem_type].__init__.__code__.co_varnames[1:]) changed_kwargs = self._change_kwargs(self.kwargs, problem_type) - return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs} + # Make sure to remove any params that are set to None and use defaults + filtered_kwargs = {k: v for k, v in changed_kwargs.items() if v is not None} + return {k: v for k, v in filtered_kwargs.items() if k in learner_kwargs} def _change_kwargs(self, kwargs, problem_type): """Handle the kwargs to be passed to the learner before they are used. 
diff --git a/Orange/modelling/linear.py b/Orange/modelling/linear.py index 75e811b2a34..1c76feed64f 100644 --- a/Orange/modelling/linear.py +++ b/Orange/modelling/linear.py @@ -11,9 +11,9 @@ class SGDLearner(Fitter): def _change_kwargs(self, kwargs, problem_type): if problem_type is self.CLASSIFICATION: - kwargs['loss'] = kwargs['classification_loss'] - kwargs['epsilon'] = kwargs['classification_epsilon'] + kwargs['loss'] = kwargs.get('classification_loss') + kwargs['epsilon'] = kwargs.get('classification_epsilon') elif problem_type is self.REGRESSION: - kwargs['loss'] = kwargs['regression_loss'] - kwargs['epsilon'] = kwargs['regression_epsilon'] + kwargs['loss'] = kwargs.get('regression_loss') + kwargs['epsilon'] = kwargs.get('regression_epsilon') return kwargs diff --git a/Orange/tests/test_fitter.py b/Orange/tests/test_fitter.py index c92445511fa..cc80be36750 100644 --- a/Orange/tests/test_fitter.py +++ b/Orange/tests/test_fitter.py @@ -152,3 +152,43 @@ class DummyFitter(Fitter): pp_data = fitter.preprocess(self.heart_disease) self.assertTrue(not any( isinstance(v, ContinuousVariable) for v in pp_data.domain.variables)) + + def test_default_kwargs_with_change_kwargs(self): + """Fallback to default args in case specialized params not specified. 
+ """ + class DummyClassificationLearner(LearnerClassification): + def __init__(self, param='classification_default', **_): + super().__init__() + self.param = param + + def fit_storage(self, data): + return DummyModel(self.param) + + class DummyRegressionLearner(LearnerRegression): + def __init__(self, param='regression_default', **_): + super().__init__() + self.param = param + + def fit_storage(self, data): + return DummyModel(self.param) + + class DummyModel: + def __init__(self, param): + self.param = param + + class DummyFitter(Fitter): + __fits__ = {'classification': DummyClassificationLearner, + 'regression': DummyRegressionLearner} + + def _change_kwargs(self, kwargs, problem_type): + if problem_type == self.CLASSIFICATION: + kwargs['param'] = kwargs.get('classification_param') + else: + kwargs['param'] = kwargs.get('regression_param') + return kwargs + + learner = DummyFitter() + iris, housing = Table('iris')[:5], Table('housing')[:5] + self.assertEqual(learner(iris).param, 'classification_default') + self.assertEqual(learner(housing).param, 'regression_default') +