Skip to content

Commit

Permalink
owneuralnetwork: connect callbacks to NN through n_iters_
Browse files Browse the repository at this point in the history
MLPClassifier n_iters_ was made a property which calls a callback.
  • Loading branch information
markotoplak committed Mar 19, 2018
1 parent 3fc7292 commit e7c89a2
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 27 deletions.
5 changes: 4 additions & 1 deletion Orange/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,11 @@ def __call__(self, data):
m.params = self.params
return m

def _initialize_wrapped(self):
return self.__wraps__(**self.params)

def fit(self, X, Y, W=None):
clf = self.__wraps__(**self.params)
clf = self._initialize_wrapped()
Y = Y.reshape(-1)
if W is None or not self.supports_weights:
return self.__returns__(clf.fit(X, Y))
Expand Down
25 changes: 24 additions & 1 deletion Orange/classification/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,28 @@
__all__ = ["NNClassificationLearner"]


class NIterCallbackMixin:
orange_callback = None

@property
def n_iter_(self):
return self.__orange_n_iter

@n_iter_.setter
def n_iter_(self, v):
self.__orange_n_iter = v
if self.orange_callback:
self.orange_callback(v)


class MLPClassifierWCallback(skl_nn.MLPClassifier, NIterCallbackMixin):
pass


class NNClassificationLearner(NNBase, SklLearner):
__wraps__ = skl_nn.MLPClassifier
__wraps__ = MLPClassifierWCallback

def _initialize_wrapped(self):
clf = SklLearner._initialize_wrapped(self)
clf.orange_callback = getattr(self, "callback", None)
return clf
7 changes: 7 additions & 0 deletions Orange/modelling/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@
class NNLearner(SklFitter):
__fits__ = {'classification': NNClassificationLearner,
'regression': NNRegressionLearner}

callback = None

def get_learner(self, problem_type):
learner = super().get_learner(problem_type)
learner.callback = self.callback
return learner
12 changes: 11 additions & 1 deletion Orange/regression/neural_network.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import sklearn.neural_network as skl_nn
from Orange.base import NNBase
from Orange.regression import SklLearner
from Orange.classification.neural_network import NIterCallbackMixin

__all__ = ["NNRegressionLearner"]


class MLPRegressorWCallback(skl_nn.MLPRegressor, NIterCallbackMixin):
pass


class NNRegressionLearner(NNBase, SklLearner):
__wraps__ = skl_nn.MLPRegressor
__wraps__ = MLPRegressorWCallback

def _initialize_wrapped(self):
clf = SklLearner._initialize_wrapped(self)
clf.orange_callback = getattr(self, "callback", None)
return clf
35 changes: 11 additions & 24 deletions Orange/widgets/model/owneuralnetwork.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from functools import partial
import copy
import logging
import re
import sys
from unittest.mock import patch
import concurrent.futures

from AnyQt.QtWidgets import QApplication, qApp
Expand All @@ -20,7 +20,6 @@
)



class Task:
"""
A class that will hold the state for an learner evaluation.
Expand Down Expand Up @@ -152,32 +151,20 @@ def __update(self):

max_iter = self.learner.kwargs["max_iter"]

def callback(iteration=None):
def callback(iteration):
if task.cancelled:
raise CancelThreadException() # this stop the thread
if iteration is not None:
set_progress(iteration/max_iter*100)

def print_callback(*args, **kwargs):
iters = None
# try to parse iteration number
if args and args[0] and isinstance(args[0], str):
find = re.findall(r"Iteration (\d+)", args[0])
if find:
iters = int(find[0])
callback(iters)
set_progress(iteration/max_iter*100)

# copy to set the callback so that the learner output is not modified
# (currently we can not pass callbacks to learners __call__)
learner = copy.copy(self.learner)
learner.callback = callback

def build_model(data, learner):
if learner.kwargs["solver"] != "lbfgs":
# enable verbose printouts within scikit and redirect them
with patch.dict(learner.kwargs, {"verbose": True}),\
patch("builtins.print", print_callback):
return learner(data)
else:
# lbfgs solver uses different mechanism
return learner(data)

build_model_func = partial(build_model, self.data, self.learner)
return learner(data)

build_model_func = partial(build_model, self.data, learner)

self.progressBarInit()

Expand Down

0 comments on commit e7c89a2

Please sign in to comment.