
Merge pull request #2127 from pavlin-policar/fitter-change-kwargs-default-params

[FIX] Fitter: fall back to defaults when changed params are None
janezd authored Mar 27, 2017
2 parents a420f95 + 417e550 commit 1fc9cd8
Showing 3 changed files with 51 additions and 10 deletions.
13 changes: 7 additions & 6 deletions Orange/modelling/base.py
@@ -30,7 +30,7 @@ class Fitter(Learner, metaclass=FitterMeta):
     learners.
     """
-    __fits__ = None
+    __fits__ = {}
     __returns__ = Model
 
     # Constants to indicate what kind of problem we're dealing with
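The commit doesn't say why {} replaces None as the class-level placeholder; one plausible reason (an inference, not stated in the source) is that an empty dict supports lookups, .get(), and iteration without a None guard. A standalone sketch, unrelated to Orange's actual classes:

# Hypothetical demo; 'classification' mirrors the keys used in __fits__,
# but nothing here touches Orange itself.
fits_none = None
fits_empty = {}

try:
    fits_none['classification']
except TypeError:
    pass  # None is not subscriptable at all

try:
    fits_empty['classification']
except KeyError:
    pass  # an empty dict fails per-key, like any populated dict would

assert fits_empty.get('classification') is None
assert list(fits_empty) == []  # iteration also works without guards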
@@ -83,7 +83,9 @@ def __kwargs(self, problem_type):
         learner_kwargs = set(
             self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
         changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
-        return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
+        # Make sure to remove any params that are set to None and use defaults
+        filtered_kwargs = {k: v for k, v in changed_kwargs.items() if v is not None}
+        return {k: v for k, v in filtered_kwargs.items() if k in learner_kwargs}
 
     def _change_kwargs(self, kwargs, problem_type):
         """Handle the kwargs to be passed to the learner before they are used.
@@ -104,10 +106,9 @@ def supports_weights(self):
         """The fitter supports weights if both the classification and
         regression learners support weights."""
         return (
-            hasattr(self.get_learner(self.CLASSIFICATION), 'supports_weights')
-            and self.get_learner(self.CLASSIFICATION).supports_weights) and (
-            hasattr(self.get_learner(self.REGRESSION), 'supports_weights')
-            and self.get_learner(self.REGRESSION).supports_weights)
+            getattr(self.get_learner(self.CLASSIFICATION), 'supports_weights', False) and
+            getattr(self.get_learner(self.REGRESSION), 'supports_weights', False)
+        )
 
     @property
     def params(self):
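The supports_weights rewrite swaps the hasattr-then-read pattern for getattr with a False default; the two are equivalent in a boolean context (assuming supports_weights is never falsy-but-meaningful). A plain-Python illustration, independent of Orange:

class WithWeights:
    supports_weights = True

class WithoutWeights:
    pass

for obj in (WithWeights(), WithoutWeights()):
    old_style = (hasattr(obj, 'supports_weights')
                 and obj.supports_weights)
    new_style = getattr(obj, 'supports_weights', False)
    assert bool(old_style) == bool(new_style)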
8 changes: 4 additions & 4 deletions Orange/modelling/linear.py
@@ -13,9 +13,9 @@ class SGDLearner(Fitter):
 
     def _change_kwargs(self, kwargs, problem_type):
         if problem_type is self.CLASSIFICATION:
-            kwargs['loss'] = kwargs['classification_loss']
-            kwargs['epsilon'] = kwargs['classification_epsilon']
+            kwargs['loss'] = kwargs.get('classification_loss')
+            kwargs['epsilon'] = kwargs.get('classification_epsilon')
         elif problem_type is self.REGRESSION:
-            kwargs['loss'] = kwargs['regression_loss']
-            kwargs['epsilon'] = kwargs['regression_epsilon']
+            kwargs['loss'] = kwargs.get('regression_loss')
+            kwargs['epsilon'] = kwargs.get('regression_epsilon')
         return kwargs
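SGDLearner's side of the fix is the switch from indexing to dict.get: a missing specialized key now yields None instead of raising KeyError, and that None is subsequently stripped by __kwargs above, so the underlying learner's default applies. A small demonstration with made-up params:

kwargs = {'classification_loss': 'hinge'}  # user supplied only the loss

kwargs['loss'] = kwargs.get('classification_loss')        # 'hinge'
kwargs['epsilon'] = kwargs.get('classification_epsilon')  # None, no KeyError

# After the None is filtered out, only the loss-related keys remain:
assert {k: v for k, v in kwargs.items() if v is not None} == {
    'classification_loss': 'hinge', 'loss': 'hinge'}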
40 changes: 40 additions & 0 deletions Orange/tests/test_fitter.py
@@ -152,3 +152,43 @@ class DummyFitter(Fitter):
         pp_data = fitter.preprocess(self.heart_disease)
         self.assertTrue(not any(
             isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))
+
+    def test_default_kwargs_with_change_kwargs(self):
+        """Fall back to default args in case specialized params are not
+        specified."""
+        class DummyClassificationLearner(LearnerClassification):
+            def __init__(self, param='classification_default', **_):
+                super().__init__()
+                self.param = param
+
+            def fit_storage(self, data):
+                return DummyModel(self.param)
+
+        class DummyRegressionLearner(LearnerRegression):
+            def __init__(self, param='regression_default', **_):
+                super().__init__()
+                self.param = param
+
+            def fit_storage(self, data):
+                return DummyModel(self.param)
+
+        class DummyModel:
+            def __init__(self, param):
+                self.param = param
+
+        class DummyFitter(Fitter):
+            __fits__ = {'classification': DummyClassificationLearner,
+                        'regression': DummyRegressionLearner}
+
+            def _change_kwargs(self, kwargs, problem_type):
+                if problem_type == self.CLASSIFICATION:
+                    kwargs['param'] = kwargs.get('classification_param')
+                else:
+                    kwargs['param'] = kwargs.get('regression_param')
+                return kwargs
+
+        learner = DummyFitter()
+        iris, housing = Table('iris')[:5], Table('housing')[:5]
+        self.assertEqual(learner(iris).param, 'classification_default')
+        self.assertEqual(learner(housing).param, 'regression_default')
+
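Taken together, the test builds a fitter with no specialized params at all: every kwargs.get(...) in _change_kwargs then yields None, __kwargs strips those Nones, and each dummy learner falls back to its own default. To run just this test from a checkout (assuming the iris and housing datasets ship with Orange, as they do in standard installs): python -m unittest Orange.tests.test_fitter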