-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3291 from ajdapretnar/introduce-stacking
[ENH] Introduce stacking
- Loading branch information
Showing
11 changed files
with
336 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
# pylint: disable=wildcard-import | ||
|
||
from .ada_boost import * | ||
from .stack import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import numpy as np | ||
|
||
from Orange.base import Learner, Model | ||
from Orange.modelling import Fitter | ||
from Orange.classification import LogisticRegressionLearner | ||
from Orange.classification.base_classification import LearnerClassification | ||
from Orange.data import Domain, ContinuousVariable, Table | ||
from Orange.evaluation import CrossValidation | ||
from Orange.regression import RidgeRegressionLearner | ||
from Orange.regression.base_regression import LearnerRegression | ||
|
||
|
||
__all__ = ['StackedLearner', 'StackedClassificationLearner', | ||
'StackedRegressionLearner', 'StackedFitter'] | ||
|
||
|
||
class StackedModel(Model):
    """
    Model produced by stacking: base models plus a fitted aggregator.

    Args:
        models (list):
            fitted base models whose outputs become the aggregator's inputs
        aggregate (Model):
            fitted meta model applied to the outputs of the base models
        use_prob (bool):
            if true, base models are asked for probabilities and the
            aggregator for values and probabilities; otherwise plain
            predictions are used throughout
        domain:
            passed on to `Model.__init__`
    """

    def __init__(self, models, aggregate, use_prob=True, domain=None):
        super().__init__(domain=domain)
        self.models = models
        self.aggregate = aggregate
        self.use_prob = use_prob

    def predict_storage(self, data):
        """Predict `data` by aggregating the outputs of the base models."""
        if self.use_prob:
            # One block of probability columns per base model, side by side.
            X = np.hstack([model(data, Model.Probs) for model in self.models])
            ret = Model.ValueProbs
        else:
            # One column of predictions per base model.
            X = np.column_stack([model(data) for model in self.models])
            ret = Model.Value

        # Re-express the data in the aggregator's domain, then overwrite the
        # features with the base-model outputs; the target is filled with NaN.
        stacked_data = data.transform(self.aggregate.domain)
        stacked_data.X = X
        stacked_data.Y = np.full(X.shape[0], np.nan)
        return self.aggregate(stacked_data, ret)
|
||
|
||
class StackedLearner(Learner):
    """
    Build a stacked model: an aggregator fitted on base-model outputs.

    The base learners are evaluated with k-fold cross-validation; their
    cross-validated outputs form the training data for the aggregator.

    Args:
        learners (list):
            list of `Learner`s used for base models
        aggregate (Learner):
            Learner used to fit the meta model, aggregating predictions
            of base models
        k (int):
            number of folds for cross-validation
        preprocessors:
            passed on to `Learner.__init__`

    Returns:
        instance of StackedModel
    """

    __returns__ = StackedModel

    def __init__(self, learners, aggregate, k=5, preprocessors=None):
        super().__init__(preprocessors=preprocessors)
        self.learners = learners
        self.aggregate = aggregate
        self.k = k
        # vars() snapshots this call's local namespace, which holds the
        # constructor arguments.
        self.params = vars()

    def fit_storage(self, data):
        """Fit the base models and the aggregator on `data`."""
        cv_results = CrossValidation(data, self.learners, k=self.k)

        # Discrete targets stack class probabilities, others raw predictions.
        use_prob = bool(data.domain.class_var.is_discrete)
        if use_prob:
            X = np.hstack(cv_results.probabilities)
        else:
            X = cv_results.predicted.T

        # Synthetic features f1, f2, ... - one per stacked column.
        features = [ContinuousVariable('f{}'.format(index))
                    for index in range(1, X.shape[1] + 1)]
        stacked_domain = Domain(features, data.domain.class_var)

        stacked_data = data.transform(stacked_domain)
        stacked_data.X = X
        stacked_data.Y = cv_results.actual

        # Base models are fitted on the given data; the aggregator is fitted
        # on the cross-validated outputs.
        base_models = [learner(data) for learner in self.learners]
        aggregate_model = self.aggregate(stacked_data)
        return StackedModel(base_models, aggregate_model, use_prob=use_prob,
                            domain=data.domain)
|
||
|
||
class StackedClassificationLearner(StackedLearner, LearnerClassification):
    """
    Subclass of StackedLearner intended for classification tasks.

    Same as the super class, but has a default
    classification-specific aggregator (`LogisticRegressionLearner`).

    Args:
        learners (list):
            list of `Learner`s used for base models
        aggregate (Learner):
            Learner used to fit the meta model; when omitted or `None`,
            a new `LogisticRegressionLearner` is created for this instance
        k (int):
            number of folds for cross-validation
        preprocessors:
            passed on to `StackedLearner.__init__`
    """

    def __init__(self, learners, aggregate=None, k=5, preprocessors=None):
        # The default aggregator is built here rather than in the signature:
        # a default written in the signature is evaluated once, at definition
        # time, and that single object would be shared by every instance.
        if aggregate is None:
            aggregate = LogisticRegressionLearner()
        super().__init__(learners, aggregate, k=k, preprocessors=preprocessors)
|
||
|
||
class StackedRegressionLearner(StackedLearner, LearnerRegression):
    """
    Subclass of StackedLearner intended for regression tasks.

    Same as the super class, but has a default
    regression-specific aggregator (`RidgeRegressionLearner`).

    Args:
        learners (list):
            list of `Learner`s used for base models
        aggregate (Learner):
            Learner used to fit the meta model; when omitted or `None`,
            a new `RidgeRegressionLearner` is created for this instance
        k (int):
            number of folds for cross-validation
        preprocessors:
            passed on to `StackedLearner.__init__`
    """

    def __init__(self, learners, aggregate=None, k=5, preprocessors=None):
        # The default aggregator is built here rather than in the signature:
        # a default written in the signature is evaluated once, at definition
        # time, and that single object would be shared by every instance.
        if aggregate is None:
            aggregate = RidgeRegressionLearner()
        super().__init__(learners, aggregate, k=k, preprocessors=preprocessors)
|
||
|
||
class StackedFitter(Fitter):
    """
    Fitter mapping classification and regression problems to the
    corresponding stacked learner class.
    """

    __fits__ = {'classification': StackedClassificationLearner,
                'regression': StackedRegressionLearner}

    def __init__(self, learners, **kwargs):
        # `learners` travels together with the other keyword arguments.
        merged_kwargs = dict(kwargs, learners=learners)
        super().__init__(**merged_kwargs)
|
||
|
||
if __name__ == '__main__':
    import Orange

    # Classification demo: fit on every other iris row, predict the rest.
    iris = Table('iris')
    knn = Orange.modelling.KNNLearner()
    tree = Orange.modelling.TreeLearner()
    stacked_classifier = StackedFitter([tree, knn])(iris[::2])
    print(stacked_classifier(iris[1::2], Model.Value))

    # Regression demo: print (actual, predicted) pairs for a few held-out rows.
    housing = Table('housing')
    stacked_regressor = StackedFitter([tree, knn])(housing[::2])
    held_out = housing[1:10:2]
    print(list(zip(held_out.Y, stacked_regressor(held_out, Model.Value))))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import unittest | ||
|
||
from Orange.data import Table | ||
from Orange.ensembles.stack import StackedFitter | ||
from Orange.evaluation import CA, CrossValidation, MSE | ||
from Orange.modelling import KNNLearner, TreeLearner | ||
|
||
|
||
class TestStackedFitter(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Both tables are loaded once and shared by the tests.
        cls.iris = Table('iris')
        cls.housing = Table('housing')

    def test_classification(self):
        # A stacked tree + kNN must exceed 0.9 accuracy in 3-fold CV on iris.
        stacked = StackedFitter([TreeLearner(), KNNLearner()])
        cv_results = CrossValidation(self.iris, [stacked], k=3)
        self.assertGreater(CA(cv_results), 0.9)

    def test_regression(self):
        # The stacked learner must have a lower MSE than each base learner.
        contenders = [StackedFitter([TreeLearner(), KNNLearner()]),
                      TreeLearner(), KNNLearner()]
        cv_results = CrossValidation(self.housing[:50], contenders, k=3,
                                     random_state=0)
        stacked_mse, tree_mse, knn_mse = MSE()(cv_results)
        self.assertLess(stacked_mse, tree_mse)
        self.assertLess(stacked_mse, knn_mse)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from collections import OrderedDict | ||
|
||
from Orange.base import Learner | ||
from Orange.data import Table | ||
from Orange.ensembles.stack import StackedFitter | ||
from Orange.widgets.settings import Setting | ||
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner | ||
from Orange.widgets.widget import Input | ||
|
||
|
||
class OWStackedLearner(OWBaseLearner):
    """Widget that builds a stacked learner from the learners on its inputs."""

    name = "Stacking"
    description = "Stack multiple models."
    icon = "icons/Stacking.svg"
    priority = 100

    LEARNER = StackedFitter

    learner_name = Setting("Stack")

    class Inputs(OWBaseLearner.Inputs):
        learners = Input("Learners", Learner, multiple=True)
        aggregate = Input("Aggregate", Learner)

    def __init__(self):
        # State is set up before the base initializer runs.
        # NOTE(review): presumably the base class may call create_learner()
        # during initialization - confirm before reordering.
        self.learners = OrderedDict()
        self.aggregate = None
        super().__init__()

    def add_main_layout(self):
        # This widget adds nothing to the main layout.
        pass

    @Inputs.learners
    def set_learners(self, learner, id):
        """Store or remove the base learner for input `id`, then apply."""
        if learner is None:
            # Removing an unknown id is a no-op.
            self.learners.pop(id, None)
        else:
            self.learners[id] = learner
        self.apply()

    @Inputs.aggregate
    def set_aggregate(self, aggregate):
        """Remember the aggregating learner (may be None), then apply."""
        self.aggregate = aggregate
        self.apply()

    def create_learner(self):
        """Return a stacked learner, or None when there are no base learners."""
        if self.learners:
            base_learners = tuple(self.learners.values())
            return self.LEARNER(base_learners, aggregate=self.aggregate,
                                preprocessors=self.preprocessors)
        return None

    def get_learner_parameters(self):
        """Return (label, value) pairs describing the current configuration."""
        base_names = [learner.name for learner in self.learners.values()]
        aggregator_name = self.aggregate.name if self.aggregate else 'default'
        return (("Base learners", base_names),
                ("Aggregator", aggregator_name))
|
||
|
||
if __name__ == "__main__":
    import sys
    from AnyQt.QtWidgets import QApplication

    # Manual preview: show the widget with the dataset named on the command
    # line, falling back to iris.
    app = QApplication(sys.argv)
    widget = OWStackedLearner()
    dataset_name = sys.argv[1] if len(sys.argv) > 1 else 'iris'
    widget.set_data(Table(dataset_name))
    widget.show()
    app.exec_()
    widget.saveSettings()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Test methods with long descriptive names can omit docstrings | ||
# pylint: disable=missing-docstring | ||
from Orange.data import Table | ||
from Orange.widgets.model.owstack import OWStackedLearner | ||
from Orange.classification import LogisticRegressionLearner | ||
from Orange.widgets.tests.base import WidgetTest | ||
|
||
|
||
class TestOWStackedLearner(WidgetTest):
    def setUp(self):
        # Auto-apply is off, so outputs change only when Apply is clicked.
        self.widget = self.create_widget(
            OWStackedLearner, stored_settings={"auto_apply": False})
        self.data = Table('iris')

    def _click_apply(self):
        # Press the widget's Apply button.
        self.widget.apply_button.button.click()

    def test_input_data(self):
        """Check widget's data with data on the input"""
        self.assertEqual(self.widget.data, None)
        self.send_signal("Data", self.data)
        self.assertEqual(self.widget.data, self.data)
        self.wait_until_stop_blocking()

    def test_output_learner(self):
        """Check if learner is on output after apply"""
        self.assertIsNone(self.get_output(self.widget.Outputs.model))
        self.send_signal("Learners", LogisticRegressionLearner(), 0)

        self._click_apply()
        initial = self.get_output("Learner")
        self.assertIsNotNone(initial, "Does not initialize the learner output")

        # A second Apply must produce a new learner object.
        self._click_apply()
        newlearner = self.get_output("Learner")
        self.assertIsNot(initial, newlearner,
                         "Does not send a new learner instance on `Apply`.")
        self.assertIsNotNone(newlearner)
        self.assertIsInstance(newlearner, self.widget.LEARNER)

    def test_output_model(self):
        """Check if model is on output after sending data and apply"""
        model_output = self.widget.Outputs.model
        self.assertIsNone(self.get_output(model_output))
        self.send_signal("Learners", LogisticRegressionLearner(), 0)

        # No data yet, so no model is expected after Apply.
        self._click_apply()
        self.assertIsNone(self.get_output(model_output))

        self.send_signal('Data', self.data)
        self._click_apply()
        self.wait_until_stop_blocking()
        model = self.get_output(model_output)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, self.widget.LEARNER.__returns__)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+146 KB
doc/visual-programming/source/widgets/model/images/Stacking-Example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+7.71 KB
doc/visual-programming/source/widgets/model/images/Stacking-stamped.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
Stacking | ||
======== | ||
|
||
Stack multiple models. | ||
|
||
Inputs | ||
Data | ||
input dataset | ||
Preprocessor | ||
preprocessing method(s) | ||
Learners | ||
learning algorithms | ||
Aggregate | ||
model aggregation method | ||
|
||
Outputs | ||
Learner | ||
aggregated (stacked) learning algorithm | ||
Model | ||
trained model | ||
|
||
|
||
**Stacking** is an ensemble method that computes a meta model from several base models. The **Stacking** widget has the **Aggregate** input, which provides a method for aggregating the input models. If no aggregation input is given, the default methods are used. Those are **Logistic Regression** for classification and **Ridge Regression** for regression problems. | ||
|
||
.. figure:: images/Stacking-stamped.png | ||
:scale: 50% | ||
|
||
1. The meta learner can be given a name under which it will appear in other widgets. The default name is “Stack”. | ||
2. Click *Apply* to commit the aggregated model. That will put the new learner in the output and, if the training examples are given, construct a new model and output it as well. To communicate changes automatically tick *Apply Automatically*. | ||
3. Access help and produce a report. | ||
|
||
Example | ||
------- | ||
|
||
We will use **Paint Data** to demonstrate how the widget is used. We painted a complex dataset with 4 class labels and sent it to **Test & Score**. We also provided three **kNN** learners, each with different parameters (number of neighbors is 5, 10 or 15). Evaluation results are good, but can we do better? | ||
|
||
Let's use **Stacking**. **Stacking** requires several learners on the input and an aggregation method. In our case, this is **Logistic Regression**. A constructed meta learner is then sent to **Test & Score**. Results have improved, even if only marginally. **Stacking** normally works well on complex data sets. | ||
|
||
.. figure:: images/Stacking-Example.png | ||
|