Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fitter: Change params uses default if None #2127

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions Orange/modelling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Fitter(Learner, metaclass=FitterMeta):
learners.

"""
__fits__ = None
__fits__ = {}
__returns__ = Model

# Constants to indicate what kind of problem we're dealing with
Expand Down Expand Up @@ -83,7 +83,9 @@ def __kwargs(self, problem_type):
learner_kwargs = set(
self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
# Make sure to remove any params that are set to None and use defaults
filtered_kwargs = {k: v for k, v in changed_kwargs.items() if v is not None}
return {k: v for k, v in filtered_kwargs.items() if k in learner_kwargs}

def _change_kwargs(self, kwargs, problem_type):
"""Handle the kwargs to be passed to the learner before they are used.
Expand All @@ -104,10 +106,9 @@ def supports_weights(self):
"""The fitter supports weights if both the classification and
regression learners support weights."""
return (
hasattr(self.get_learner(self.CLASSIFICATION), 'supports_weights')
and self.get_learner(self.CLASSIFICATION).supports_weights) and (
hasattr(self.get_learner(self.REGRESSION), 'supports_weights')
and self.get_learner(self.REGRESSION).supports_weights)
getattr(self.get_learner(self.CLASSIFICATION), 'supports_weights', False) and
getattr(self.get_learner(self.REGRESSION), 'supports_weights', False)
)

@property
def params(self):
Expand Down
8 changes: 4 additions & 4 deletions Orange/modelling/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class SGDLearner(Fitter):

def _change_kwargs(self, kwargs, problem_type):
if problem_type is self.CLASSIFICATION:
kwargs['loss'] = kwargs['classification_loss']
kwargs['epsilon'] = kwargs['classification_epsilon']
kwargs['loss'] = kwargs.get('classification_loss')
kwargs['epsilon'] = kwargs.get('classification_epsilon')
elif problem_type is self.REGRESSION:
kwargs['loss'] = kwargs['regression_loss']
kwargs['epsilon'] = kwargs['regression_epsilon']
kwargs['loss'] = kwargs.get('regression_loss')
kwargs['epsilon'] = kwargs.get('regression_epsilon')
return kwargs
40 changes: 40 additions & 0 deletions Orange/tests/test_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,43 @@ class DummyFitter(Fitter):
pp_data = fitter.preprocess(self.heart_disease)
self.assertTrue(not any(
isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))

def test_default_kwargs_with_change_kwargs(self):
    """Fallback to default args in case specialized params not specified.

    ``_change_kwargs`` below maps problem-specific keys onto ``param``
    using ``dict.get``, which yields ``None`` when the specialized key
    was never passed; the Fitter is expected to drop those ``None``
    entries so each underlying learner's own default value applies.
    """
    # Minimal classification learner: its only state is ``param``,
    # which defaults to a value unique to this learner.
    class DummyClassificationLearner(LearnerClassification):
        def __init__(self, param='classification_default', **_):
            super().__init__()
            self.param = param

        def fit_storage(self, data):
            return DummyModel(self.param)

    # Regression counterpart with a distinct default so the assertions
    # can tell which learner actually produced the model.
    class DummyRegressionLearner(LearnerRegression):
        def __init__(self, param='regression_default', **_):
            super().__init__()
            self.param = param

        def fit_storage(self, data):
            return DummyModel(self.param)

    # Trivial model object: just records the param it was built with.
    class DummyModel:
        def __init__(self, param):
            self.param = param

    class DummyFitter(Fitter):
        __fits__ = {'classification': DummyClassificationLearner,
                    'regression': DummyRegressionLearner}

        def _change_kwargs(self, kwargs, problem_type):
            # ``get`` returns None when the specialized key is absent;
            # the Fitter must filter such entries out before dispatching
            # to the concrete learner.
            if problem_type == self.CLASSIFICATION:
                kwargs['param'] = kwargs.get('classification_param')
            else:
                kwargs['param'] = kwargs.get('regression_param')
            return kwargs

    learner = DummyFitter()
    # No specialized params are passed, so on both problem types the
    # respective learner's constructor default must win.
    iris, housing = Table('iris')[:5], Table('housing')[:5]
    self.assertEqual(learner(iris).param, 'classification_default')
    self.assertEqual(learner(housing).param, 'regression_default')