Skip to content

Commit

Permalink
A try to hack the callback into skl NN by making n_iters_ a property.
Browse files Browse the repository at this point in the history
  • Loading branch information
markotoplak committed Mar 16, 2018
1 parent 3fc7292 commit c6dcd6a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 25 deletions.
25 changes: 24 additions & 1 deletion Orange/classification/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,28 @@
__all__ = ["NNClassificationLearner"]


class MLPClassifierWCallback(skl_nn.MLPClassifier):

orange_callback = None

@property
def n_iter_(self):
return self.__orange_n_iter

@n_iter_.setter
def n_iter_(self, v):
self.__orange_n_iter = v
if self.orange_callback:
self.orange_callback(v)


class NNClassificationLearner(NNBase, SklLearner):
__wraps__ = skl_nn.MLPClassifier
__wraps__ = MLPClassifierWCallback

def fit(self, X, Y, W=None):
clf = self.__wraps__(**self.params)
clf.orange_callback = getattr(self, "callback", None)
Y = Y.reshape(-1)
if W is None or not self.supports_weights:
return self.__returns__(clf.fit(X, Y))
return self.__returns__(clf.fit(X, Y, sample_weight=W.reshape(-1)))
9 changes: 9 additions & 0 deletions Orange/modelling/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@
class NNLearner(SklFitter):
__fits__ = {'classification': NNClassificationLearner,
'regression': NNRegressionLearner}

def __init__(self, callback=None, **kwargs):
super().__init__(**kwargs)
self.callback = callback

def get_learner(self, problem_type):
learner = super().get_learner(problem_type)
learner.callback = self.callback
return learner
41 changes: 17 additions & 24 deletions Orange/widgets/model/owneuralnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.modelling import Fitter

from Orange.widgets.utils.concurrent import (
ThreadExecutor, FutureWatcher, methodinvoke
Expand Down Expand Up @@ -101,14 +102,15 @@ def setup_layout(self):
# just a test cancel button
gui.button(self.controlArea, self, "Cancel", callback=self.cancel)

def create_learner(self):
def create_learner(self, callback=None):
return self.LEARNER(
hidden_layer_sizes=self.get_hidden_layers(),
activation=self.activation[self.activation_index],
solver=self.solver[self.solver_index],
alpha=self.alpha,
max_iter=self.max_iterations,
preprocessors=self.preprocessors)
preprocessors=self.preprocessors,
callback=callback)

def get_learner_parameters(self):
return (("Hidden layers", ', '.join(map(str, self.get_hidden_layers()))),
Expand All @@ -124,6 +126,14 @@ def get_hidden_layers(self):
self.hidden_layers_edit.setText("100,")
return layers

def new_model_learner(self, callback):
learner = self.create_learner(callback)
if learner and issubclass(self.LEARNER, Fitter):
learner.use_default_preprocessors = True
if self.learner is not None:
learner.name = self.learner_name
return learner

def update_model(self):
self.show_fitting_failed(None)
self.model = None
Expand Down Expand Up @@ -152,32 +162,15 @@ def __update(self):

max_iter = self.learner.kwargs["max_iter"]

def callback(iteration=None):
def callback(iteration):
if task.cancelled:
raise CancelThreadException() # this stop the thread
if iteration is not None:
set_progress(iteration/max_iter*100)

def print_callback(*args, **kwargs):
iters = None
# try to parse iteration number
if args and args[0] and isinstance(args[0], str):
find = re.findall(r"Iteration (\d+)", args[0])
if find:
iters = int(find[0])
callback(iters)
set_progress(iteration/max_iter*100)

def build_model(data, learner):
if learner.kwargs["solver"] != "lbfgs":
# enable verbose printouts within scikit and redirect them
with patch.dict(learner.kwargs, {"verbose": True}),\
patch("builtins.print", print_callback):
return learner(data)
else:
# lbfgs solver uses different mechanism
return learner(data)

build_model_func = partial(build_model, self.data, self.learner)
return learner(data)

build_model_func = partial(build_model, self.data, self.new_model_learner(callback))

self.progressBarInit()

Expand Down

0 comments on commit c6dcd6a

Please sign in to comment.