Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Neural network widget that works in a separate thread #2958

Merged
merged 5 commits into from
Jun 1, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 158 additions & 3 deletions Orange/widgets/model/owneuralnetwork.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,49 @@
from functools import partial
import logging
import re
import sys
from unittest.mock import patch
import concurrent.futures

from AnyQt.QtWidgets import QApplication
from AnyQt.QtCore import Qt
from AnyQt.QtWidgets import QApplication, qApp
from AnyQt.QtCore import Qt, QThread
from AnyQt.QtCore import pyqtSlot as Slot

from Orange.data import Table
from Orange.modelling import NNLearner
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner

from Orange.widgets.utils.concurrent import (
ThreadExecutor, FutureWatcher, methodinvoke
)



class Task:
"""
A class that will hold the state for an learner evaluation.
"""
future = ... # type: concurrent.futures.Future
watcher = ... # type: FutureWatcher
cancelled = False # type: bool

def cancel(self):
"""
Cancel the task.

Set the `cancelled` field to True and block until the future is done.
"""
# set cancelled state
self.cancelled = True
self.future.cancel()
concurrent.futures.wait([self.future])


class CancelThreadException(BaseException):
pass


class OWNNLearner(OWBaseLearner):
name = "Neural Network"
Expand Down Expand Up @@ -53,11 +87,20 @@ def add_main_layout(self):
label="Alpha:", decimals=5, alignment=Qt.AlignRight,
callback=self.settings_changed, controlWidth=80)
self.max_iter_spin = gui.spin(
box, self, "max_iterations", 10, 300, step=10,
box, self, "max_iterations", 10, 10000, step=10,
label="Max iterations:", orientation=Qt.Horizontal,
alignment=Qt.AlignRight, callback=self.settings_changed,
controlWidth=80)

def setup_layout(self):
super().setup_layout()

self._task = None # type: Optional[Task]
self._executor = ThreadExecutor()

# just a test cancel button
gui.button(self.controlArea, self, "Cancel", callback=self.cancel)

def create_learner(self):
return self.LEARNER(
hidden_layer_sizes=self.get_hidden_layers(),
Expand All @@ -81,6 +124,118 @@ def get_hidden_layers(self):
self.hidden_layers_edit.setText("100,")
return layers

def update_model(self):
self.show_fitting_failed(None)
self.model = None
if self.check_data():
self.__update()
else:
self.Outputs.model.send(self.model)

@Slot(float)
def setProgressValue(self, value):
assert self.thread() is QThread.currentThread()
self.progressBarSet(value)

def __update(self):
if self._task is not None:
# First make sure any pending tasks are cancelled.
self.cancel()
assert self._task is None

self.setBlocking(True)

self._task = task = Task()

# A thread safe way to invoke a method
set_progress = methodinvoke(self, "setProgressValue", (float,))

max_iter = self.learner.kwargs["max_iter"]

def callback(iteration=None):
if task.cancelled:
raise CancelThreadException() # this stop the thread
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment (and exception name) are not correct. This does not stop the thread. It cancels/interrupts the task.

if iteration is not None:
set_progress(iteration/max_iter*100)

def print_callback(*args, **kwargs):
iters = None
# try to parse iteration number
if args and args[0] and isinstance(args[0], str):
find = re.findall(r"Iteration (\d+)", args[0])
if find:
iters = int(find[0])
callback(iters)

def build_model(data, learner):
if learner.kwargs["solver"] != "lbfgs":
# enable verbose printouts within scikit and redirect them
with patch.dict(learner.kwargs, {"verbose": True}),\
patch("builtins.print", print_callback):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Time -->

Thread-1
-|patch|-----|exit|--------

Thread-2
----|patch|--------|exit|--
          ^             ^
         stores         restores
     patched print      the patched print of Thread-1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops. Thank you. Did not know that patch mutates state globally....

return learner(data)
else:
# lbfgs solver uses different mechanism
return learner(data)

build_model_func = partial(build_model, self.data, self.learner)

self.progressBarInit()

task.future = self._executor.submit(build_model_func)
task.watcher = FutureWatcher(task.future)
task.watcher.done.connect(self._task_finished)

@Slot(concurrent.futures.Future)
def _task_finished(self, f):
"""
Parameters
----------
f : Future
The future instance holding the built model
"""
assert self.thread() is QThread.currentThread()
assert self._task is not None
assert self._task.future is f
assert f.done()

self.setBlocking(False)

self._task = None
self.progressBarFinished()

try:
self.model = f.result()
except Exception as ex: # pylint: disable=broad-except
# Log the exception with a traceback
log = logging.getLogger()
log.exception(__name__, exc_info=True)
self.model = None
self.show_fitting_failed(ex)
else:
self.model.name = self.learner_name
self.model.instances = self.data
self.Outputs.model.send(self.model)

def cancel(self):
"""
Cancel the current task (if any).
"""
if self._task is not None:
self._task.cancel()
assert self._task.future.done()
# disconnect the `_task_finished` slot
self._task.watcher.done.disconnect(self._task_finished)
self._task = None
# threads use signals to run functions in the main thread and some
# can still be quoued (perhaps change)
qApp.processEvents()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this. If needed route the progress updates via an intermediary QObject maybe like this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I haven't thought of this.

self.progressBarFinished()
self.setBlocking(False)

def onDeleteWidget(self):
self.cancel()
super().onDeleteWidget()


if __name__ == "__main__":
a = QApplication(sys.argv)
Expand Down
16 changes: 16 additions & 0 deletions Orange/widgets/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,13 +514,16 @@ def test_input_data(self):
self.assertEqual(self.widget.data, None)
self.send_signal("Data", self.data)
self.assertEqual(self.widget.data, self.data)
self.wait_until_stop_blocking()

def test_input_data_disconnect(self):
"""Check widget's data and model after disconnecting data from input"""
self.send_signal("Data", self.data)
self.assertEqual(self.widget.data, self.data)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.send_signal("Data", None)
self.wait_until_stop_blocking()
self.assertEqual(self.widget.data, None)
self.assertIsNone(self.get_output(self.widget.Outputs.model))

Expand All @@ -529,9 +532,11 @@ def test_input_data_learner_adequacy(self):
for inadequate in self.inadequate_dataset:
self.send_signal("Data", inadequate)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertTrue(self.widget.Error.data_error.is_shown())
for valid in self.valid_datasets:
self.send_signal("Data", valid)
self.wait_until_stop_blocking()
self.assertFalse(self.widget.Error.data_error.is_shown())

def test_input_preprocessor(self):
Expand All @@ -542,6 +547,7 @@ def test_input_preprocessor(self):
randomize, self.widget.preprocessors,
'Preprocessor not added to widget preprocessors')
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(
(randomize,), self.widget.learner.preprocessors,
'Preprocessors were not passed to the learner')
Expand All @@ -551,6 +557,7 @@ def test_input_preprocessors(self):
pp_list = PreprocessorList([Randomize(), RemoveNaNColumns()])
self.send_signal("Preprocessor", pp_list)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(
(pp_list,), self.widget.learner.preprocessors,
'`PreprocessorList` was not added to preprocessors')
Expand All @@ -560,10 +567,12 @@ def test_input_preprocessor_disconnect(self):
randomize = Randomize()
self.send_signal("Preprocessor", randomize)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(randomize, self.widget.preprocessors)

self.send_signal("Preprocessor", None)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertIsNone(self.widget.preprocessors,
'Preprocessors not removed on disconnect.')

Expand All @@ -585,6 +594,7 @@ def test_output_model(self):
self.assertIsNone(self.get_output(self.widget.Outputs.model))
self.send_signal('Data', self.data)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
model = self.get_output(self.widget.Outputs.model)
self.assertIsNotNone(model)
self.assertIsInstance(model, self.widget.LEARNER.__returns__)
Expand All @@ -598,6 +608,7 @@ def test_output_learner_name(self):
self.widget.name_line_edit.text())
self.widget.name_line_edit.setText(new_name)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(self.get_output("Learner").name, new_name)

def test_output_model_name(self):
Expand All @@ -606,6 +617,7 @@ def test_output_model_name(self):
self.widget.name_line_edit.setText(new_name)
self.send_signal("Data", self.data)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(self.get_output(self.widget.Outputs.model).name, new_name)

def _get_param_value(self, learner, param):
Expand All @@ -626,6 +638,7 @@ def test_parameters_default(self):
for dataset in self.valid_datasets:
self.send_signal("Data", dataset)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
for parameter in self.parameters:
# Skip if the param isn't used for the given data type
if self._should_check_parameter(parameter, dataset):
Expand All @@ -639,6 +652,7 @@ def test_parameters(self):
# to only certain problem types
for dataset in self.valid_datasets:
self.send_signal("Data", dataset)
self.wait_until_stop_blocking()

for parameter in self.parameters:
# Skip if the param isn't used for the given data type
Expand All @@ -650,6 +664,7 @@ def test_parameters(self):
for value in parameter.values:
parameter.set_value(value)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
param = self._get_param_value(self.widget.learner, parameter)
self.assertEqual(
param, parameter.get_value(),
Expand All @@ -674,6 +689,7 @@ def test_params_trigger_settings_changed(self):
"""Check that the learner gets updated whenever a param is changed."""
for dataset in self.valid_datasets:
self.send_signal("Data", dataset)
self.wait_until_stop_blocking()

for parameter in self.parameters:
# Skip if the param isn't used for the given data type
Expand Down