Fitter: Change params uses default if None
pavlin-policar committed Mar 21, 2017
1 parent 571be99 commit f6695ca
Showing 3 changed files with 47 additions and 5 deletions.
4 changes: 3 additions & 1 deletion Orange/modelling/base.py
@@ -83,7 +83,9 @@ def __kwargs(self, problem_type):
         learner_kwargs = set(
             self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
         changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
-        return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
+        # Make sure to remove any params that are set to None and use defaults
+        filtered_kwargs = {k: v for k, v in changed_kwargs.items() if v is not None}
+        return {k: v for k, v in filtered_kwargs.items() if k in learner_kwargs}
 
     def _change_kwargs(self, kwargs, problem_type):
         """Handle the kwargs to be passed to the learner before they are used.
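The point of the new filtering step: `_change_kwargs` implementations fill keys like `loss` via `dict.get`, which yields `None` when the specialized parameter was never passed, and that `None` would otherwise be forwarded to the learner's `__init__` and override its default. A minimal standalone sketch of the mechanism (the `DummyLearner` class and `build` helper are illustrative, not part of Orange):

    # Sketch of the kwarg filtering introduced above, under assumed names.
    class DummyLearner:
        def __init__(self, loss='squared_loss', epsilon=0.1):
            self.loss, self.epsilon = loss, epsilon

    def build(**kwargs):
        # _change_kwargs-style mapping: .get() yields None for missing keys
        kwargs['loss'] = kwargs.get('classification_loss')
        # Drop None values so the learner's own defaults apply
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # Keep only kwargs the learner's __init__ actually accepts
        accepted = set(DummyLearner.__init__.__code__.co_varnames[1:])
        return DummyLearner(**{k: v for k, v in kwargs.items() if k in accepted})

    assert build().loss == 'squared_loss'                    # default preserved
    assert build(classification_loss='hinge').loss == 'hinge'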
8 changes: 4 additions & 4 deletions Orange/modelling/linear.py
@@ -11,9 +11,9 @@ class SGDLearner(Fitter):
 
     def _change_kwargs(self, kwargs, problem_type):
         if problem_type is self.CLASSIFICATION:
-            kwargs['loss'] = kwargs['classification_loss']
-            kwargs['epsilon'] = kwargs['classification_epsilon']
+            kwargs['loss'] = kwargs.get('classification_loss')
+            kwargs['epsilon'] = kwargs.get('classification_epsilon')
         elif problem_type is self.REGRESSION:
-            kwargs['loss'] = kwargs['regression_loss']
-            kwargs['epsilon'] = kwargs['regression_epsilon']
+            kwargs['loss'] = kwargs.get('regression_loss')
+            kwargs['epsilon'] = kwargs.get('regression_epsilon')
         return kwargs
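This companion change is what makes the filtering in `__kwargs` reachable: indexing with `kwargs['classification_loss']` raised `KeyError` whenever the caller omitted the specialized parameter, while `kwargs.get(...)` returns `None`, which `__kwargs` now strips. A hedged usage sketch (assuming `SGDLearner` is exported from `Orange.modelling`):

    from Orange.modelling import SGDLearner  # assumed import path

    learner = SGDLearner()  # no classification_loss / regression_loss passed
    # _change_kwargs sets kwargs['loss'] = kwargs.get('classification_loss'),
    # i.e. None; __kwargs filters the None out, so the wrapped learner is
    # constructed with its own default loss instead of None.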
40 changes: 40 additions & 0 deletions Orange/tests/test_fitter.py
@@ -152,3 +152,43 @@ class DummyFitter(Fitter):
         pp_data = fitter.preprocess(self.heart_disease)
         self.assertTrue(not any(
             isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))
+
+    def test_default_kwargs_with_change_kwargs(self):
+        """Fallback to default args in case specialized params not specified.
+        """
+        class DummyClassificationLearner(LearnerClassification):
+            def __init__(self, param='classification_default', **_):
+                super().__init__()
+                self.param = param
+
+            def fit_storage(self, data):
+                return DummyModel(self.param)
+
+        class DummyRegressionLearner(LearnerRegression):
+            def __init__(self, param='regression_default', **_):
+                super().__init__()
+                self.param = param
+
+            def fit_storage(self, data):
+                return DummyModel(self.param)
+
+        class DummyModel:
+            def __init__(self, param):
+                self.param = param
+
+        class DummyFitter(Fitter):
+            __fits__ = {'classification': DummyClassificationLearner,
+                        'regression': DummyRegressionLearner}
+
+            def _change_kwargs(self, kwargs, problem_type):
+                if problem_type == self.CLASSIFICATION:
+                    kwargs['param'] = kwargs.get('classification_param')
+                else:
+                    kwargs['param'] = kwargs.get('regression_param')
+                return kwargs
+
+        learner = DummyFitter()
+        iris, housing = Table('iris')[:5], Table('housing')[:5]
+        self.assertEqual(learner(iris).param, 'classification_default')
+        self.assertEqual(learner(housing).param, 'regression_default')
