Skip to content

Commit

Permalink
Merge pull request #5710 from janezd/baselearnerwidget-pp-warning
Browse files Browse the repository at this point in the history
[ENH] Learner widgets: Inform about potential problems when overriding preprocessors
  • Loading branch information
lanzagar authored Jan 7, 2022
2 parents 5d83588 + e44754d commit 41f6906
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 26 deletions.
3 changes: 1 addition & 2 deletions Orange/widgets/model/owadaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ def set_base_learner(self, learner):
self.base_estimator = learner or self.DEFAULT_BASE_ESTIMATOR
self.base_label.setText(
"Base estimator: %s" % self.base_estimator.name.title())
if self.auto_apply:
self.apply()
self.learner = self.model = None

def get_learner_parameters(self):
return (("Base estimator", self.base_estimator),
Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/model/owcalibratedlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def add_main_layout(self):
def set_learner(self, learner):
self.base_learner = learner
self._set_default_name()
self.unconditional_apply()
self.learner = self.model = None

def _set_default_name(self):

Expand Down
3 changes: 2 additions & 1 deletion Orange/widgets/model/owcurvefit.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def __insert_into_expression(self, what: str, offset=0):

def set_data(self, data: Optional[Table]):
self.Warning.data_missing(shown=not bool(data))
self.learner = None
super().set_data(data)
self.__clear()

Expand All @@ -419,7 +420,7 @@ def handleNewSignals(self):
self.__init_models()
self.__enable_controls()
self.__set_pending()
self.unconditional_apply()
super().handleNewSignals()

def __preprocess_data(self):
self.__pp_data = preprocess(self.data, self.preprocessors)
Expand Down
3 changes: 0 additions & 3 deletions Orange/widgets/model/owlinearregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ def add_main_layout(self):
self.controls.alpha_index.setEnabled(self.reg_type != self.OLS)
self.l2_ratio_slider.setEnabled(self.reg_type == self.Elastic)

def handleNewSignals(self):
self.apply()

def _intercept_changed(self):
self.apply()

Expand Down
10 changes: 7 additions & 3 deletions Orange/widgets/model/owstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,26 @@ def add_main_layout(self):
@Inputs.learners
def set_learner(self, index: int, learner: Learner):
self.learners[index] = learner
self._invalidate()

@Inputs.learners.insert
def insert_learner(self, index, learner):
self.learners.insert(index, learner)
self._invalidate()

@Inputs.learners.remove
def remove_learner(self, index):
self.learners.pop(index)
self._invalidate()

@Inputs.aggregate
def set_aggregate(self, aggregate):
self.aggregate = aggregate
self._invalidate()

def handleNewSignals(self):
super().handleNewSignals()
self.apply()
def _invalidate(self):
self.learner = self.model = None
# ... and handleNewSignals will do the rest

def create_learner(self):
if not self.learners:
Expand Down
66 changes: 51 additions & 15 deletions Orange/widgets/utils/owlearnerwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ class Error(OWWidget.Error):
class Warning(OWWidget.Warning):
outdated_learner = Msg("Press Apply to submit changes.")

class Information(OWWidget.Information):
ignored_preprocessors = Msg(
"Ignoring default preprocessing.\n"
"Default preprocessing, such as scaling, one-hot encoding and "
"treatment of missing data, has been replaced with user-specified "
"preprocessors. Problems may occur if these are inadequate "
"for the given data.")

class Inputs:
data = Input("Data", Table)
preprocessor = Input("Preprocessor", Preprocess)
Expand All @@ -90,6 +98,8 @@ class Outputs:

OUTPUT_MODEL_NAME = Outputs.model.name # Attr for backcompat w/ self.send() code

_SEND, _SOFT, _UPDATE = range(3)

def __init__(self, preprocessors=None):
super().__init__()
self.__default_learner_name = ""
Expand All @@ -99,6 +109,7 @@ def __init__(self, preprocessors=None):
self.model = None
self.preprocessors = preprocessors
self.outdated_settings = False
self.__apply_level = []

self.setup_layout()
QTimer.singleShot(0, getattr(self, "unconditional_apply", self.apply))
Expand Down Expand Up @@ -144,7 +155,8 @@ def set_default_learner_name(self, name: str) -> None:
@Inputs.preprocessor
def set_preprocessor(self, preprocessor):
self.preprocessors = preprocessor
self.apply()
# invalidate learner and model, so handleNewSignals will renew them
self.learner = self.model = None

@Inputs.data
@check_sql_input
Expand All @@ -164,23 +176,50 @@ def set_data(self, data):
"Select one with the Select Columns widget.")
self.data = None

self.update_model()
# invalidate the model so that handleNewSignals will update it
self.model = None


def apply(self):
level, self.__apply_level = max(self.__apply_level, default=self._UPDATE), []
"""Applies learner and sends new model."""
self.update_learner()
self.update_model()
if level == self._SEND:
self._send_learner()
self._send_model()
elif level == self._UPDATE:
self.update_learner()
self.update_model()
else:
self.learner or self.update_learner()
self.model or self.update_model()

def apply_as(self, level, unconditional=False):
self.__apply_level.append(level)
if unconditional:
self.unconditional_apply()
else:
self.apply()

def update_learner(self):
self.learner = self.create_learner()
if self.learner and issubclass(self.LEARNER, Fitter):
self.learner.use_default_preprocessors = True
if self.learner is not None:
self.learner.name = self.effective_learner_name()
self._send_learner()

def _send_learner(self):
self.Outputs.learner.send(self.learner)
self.outdated_settings = False
self.Warning.outdated_learner.clear()

def handleNewSignals(self):
self.apply_as(self._SOFT, True)
self.Information.ignored_preprocessors(
shown=not getattr(self.learner, "use_default_preprocessors", False)
and getattr(self.LEARNER, "preprocessors", False)
and self.preprocessors is not None)

def show_fitting_failed(self, exc):
"""Show error when fitting fails.
Derived widgets can override this to show more specific messages."""
Expand All @@ -197,6 +236,9 @@ def update_model(self):
else:
self.model.name = self.learner_name or self.captionTitle
self.model.instances = self.data
self._send_model()

def _send_model(self):
self.Outputs.model.send(self.model)

def check_data(self):
Expand All @@ -223,15 +265,12 @@ def settings_changed(self, *args, **kwargs):
self.Warning.outdated_learner(shown=not self.auto_apply)
self.apply()

def _change_name(self, instance, output):
if instance:
instance.name = self.effective_learner_name()
if self.auto_apply:
output.send(instance)

def learner_name_changed(self):
self._change_name(self.learner, self.Outputs.learner)
self._change_name(self.model, self.Outputs.model)
if self.model is not None:
self.model.name = self.effective_learner_name()
if self.learner is not None:
self.learner.name = self.effective_learner_name()
self.apply_as(self._SEND)

def effective_learner_name(self):
"""Return the effective learner name."""
Expand Down Expand Up @@ -272,7 +311,6 @@ def add_main_layout(self):
Override this method for laying out any learner-specific parameter controls.
See setup_layout() method for execution order.
"""
pass

def add_classification_layout(self, box):
"""Creates layout for classification specific options.
Expand All @@ -281,7 +319,6 @@ def add_classification_layout(self, box):
and regression learners require different options.
See `setup_layout()` method for execution order.
"""
pass

def add_regression_layout(self, box):
"""Creates layout for regression specific options.
Expand All @@ -290,7 +327,6 @@ def add_regression_layout(self, box):
and regression learners require different options.
See `setup_layout()` method for execution order.
"""
pass

def add_learner_name_widget(self):
self.name_line_edit = gui.lineEdit(
Expand Down
72 changes: 71 additions & 1 deletion Orange/widgets/utils/tests/test_owlearnerwidget.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, patch

import scipy.sparse as sp

Expand Down Expand Up @@ -218,3 +218,73 @@ def check_name(name):
check_name("Bar")
w.set_default_learner_name("")
check_name("Blarg")

def test_preprocessor_warning(self):
class TestLearnerNoPreprocess(Learner):
name = "Test"
__returns__ = Mock()

class TestWidgetNoPreprocess(OWBaseLearner):
name = "Test"
LEARNER = TestLearnerNoPreprocess

class TestLearnerPreprocess(Learner):
name = "Test"
preprocessors = [Mock()]
__returns__ = Mock()

class TestWidgetPreprocess(OWBaseLearner):
name = "Test"
LEARNER = TestLearnerPreprocess

class TestFitterPreprocess(Fitter):
name = "Test"
preprocessors = [Mock()]
__returns__ = Mock()

class TestWidgetPreprocessFit(OWBaseLearner):
name = "Test"
LEARNER = TestFitterPreprocess

wno = self.create_widget(TestWidgetNoPreprocess)
wyes = self.create_widget(TestWidgetPreprocess)
wfit = self.create_widget(TestWidgetPreprocessFit)

self.assertFalse(wno.Information.ignored_preprocessors.is_shown())
self.assertFalse(wyes.Information.ignored_preprocessors.is_shown())
self.assertFalse(wfit.Information.ignored_preprocessors.is_shown())

pp = continuize.Continuize()
self.send_signal(wno.Inputs.preprocessor, pp)
self.send_signal(wyes.Inputs.preprocessor, pp)
self.send_signal(wfit.Inputs.preprocessor, pp)

self.assertFalse(wno.Information.ignored_preprocessors.is_shown())
self.assertTrue(wyes.Information.ignored_preprocessors.is_shown())
self.assertFalse(wfit.Information.ignored_preprocessors.is_shown())

self.send_signal(wno.Inputs.preprocessor, None)
self.send_signal(wyes.Inputs.preprocessor, None)
self.send_signal(wfit.Inputs.preprocessor, None)

self.assertFalse(wno.Information.ignored_preprocessors.is_shown())
self.assertFalse(wyes.Information.ignored_preprocessors.is_shown())
self.assertFalse(wfit.Information.ignored_preprocessors.is_shown())

def test_multiple_sends(self):
class TestLearner(Learner):
name = "Test"
__returns__ = Mock()

class TestWidget(OWBaseLearner):
name = "Test"
LEARNER = TestLearner

widget = self.create_widget(TestWidget)
pp = continuize.Continuize()
with patch.object(widget.Outputs.learner, "send") as model_send, \
patch.object(widget.Outputs.model, "send") as learner_send:
self.send_signals([(widget.Inputs.data, self.iris),
(widget.Inputs.preprocessor, pp)])
learner_send.assert_called_once()
model_send.assert_called_once()

0 comments on commit 41f6906

Please sign in to comment.