-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Neural network widget that works in a separate thread #2958
Changes from 3 commits
5bb4d00
677e6f8
3fc7292
e7c89a2
48bb3bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,49 @@ | ||
from functools import partial | ||
import logging | ||
import re | ||
import sys | ||
from unittest.mock import patch | ||
import concurrent.futures | ||
|
||
from AnyQt.QtWidgets import QApplication | ||
from AnyQt.QtCore import Qt | ||
from AnyQt.QtWidgets import QApplication, qApp | ||
from AnyQt.QtCore import Qt, QThread | ||
from AnyQt.QtCore import pyqtSlot as Slot | ||
|
||
from Orange.data import Table | ||
from Orange.modelling import NNLearner | ||
from Orange.widgets import gui | ||
from Orange.widgets.settings import Setting | ||
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner | ||
|
||
from Orange.widgets.utils.concurrent import ( | ||
ThreadExecutor, FutureWatcher, methodinvoke | ||
) | ||
|
||
|
||
|
||
class Task: | ||
""" | ||
A class that will hold the state for an learner evaluation. | ||
""" | ||
future = ... # type: concurrent.futures.Future | ||
watcher = ... # type: FutureWatcher | ||
cancelled = False # type: bool | ||
|
||
def cancel(self): | ||
""" | ||
Cancel the task. | ||
|
||
Set the `cancelled` field to True and block until the future is done. | ||
""" | ||
# set cancelled state | ||
self.cancelled = True | ||
self.future.cancel() | ||
concurrent.futures.wait([self.future]) | ||
|
||
|
||
class CancelThreadException(BaseException): | ||
pass | ||
|
||
|
||
class OWNNLearner(OWBaseLearner): | ||
name = "Neural Network" | ||
|
@@ -53,11 +87,20 @@ def add_main_layout(self): | |
label="Alpha:", decimals=5, alignment=Qt.AlignRight, | ||
callback=self.settings_changed, controlWidth=80) | ||
self.max_iter_spin = gui.spin( | ||
box, self, "max_iterations", 10, 300, step=10, | ||
box, self, "max_iterations", 10, 10000, step=10, | ||
label="Max iterations:", orientation=Qt.Horizontal, | ||
alignment=Qt.AlignRight, callback=self.settings_changed, | ||
controlWidth=80) | ||
|
||
def setup_layout(self): | ||
super().setup_layout() | ||
|
||
self._task = None # type: Optional[Task] | ||
self._executor = ThreadExecutor() | ||
|
||
# just a test cancel button | ||
gui.button(self.controlArea, self, "Cancel", callback=self.cancel) | ||
|
||
def create_learner(self): | ||
return self.LEARNER( | ||
hidden_layer_sizes=self.get_hidden_layers(), | ||
|
@@ -81,6 +124,118 @@ def get_hidden_layers(self): | |
self.hidden_layers_edit.setText("100,") | ||
return layers | ||
|
||
def update_model(self): | ||
self.show_fitting_failed(None) | ||
self.model = None | ||
if self.check_data(): | ||
self.__update() | ||
else: | ||
self.Outputs.model.send(self.model) | ||
|
||
@Slot(float) | ||
def setProgressValue(self, value): | ||
assert self.thread() is QThread.currentThread() | ||
self.progressBarSet(value) | ||
|
||
def __update(self): | ||
if self._task is not None: | ||
# First make sure any pending tasks are cancelled. | ||
self.cancel() | ||
assert self._task is None | ||
|
||
self.setBlocking(True) | ||
|
||
self._task = task = Task() | ||
|
||
# A thread safe way to invoke a method | ||
set_progress = methodinvoke(self, "setProgressValue", (float,)) | ||
|
||
max_iter = self.learner.kwargs["max_iter"] | ||
|
||
def callback(iteration=None): | ||
if task.cancelled: | ||
raise CancelThreadException() # this stop the thread | ||
if iteration is not None: | ||
set_progress(iteration/max_iter*100) | ||
|
||
def print_callback(*args, **kwargs): | ||
iters = None | ||
# try to parse iteration number | ||
if args and args[0] and isinstance(args[0], str): | ||
find = re.findall(r"Iteration (\d+)", args[0]) | ||
if find: | ||
iters = int(find[0]) | ||
callback(iters) | ||
|
||
def build_model(data, learner): | ||
if learner.kwargs["solver"] != "lbfgs": | ||
# enable verbose printouts within scikit and redirect them | ||
with patch.dict(learner.kwargs, {"verbose": True}),\ | ||
patch("builtins.print", print_callback): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whoops. Thank you. Did not know that patch mutates state globally.... |
||
return learner(data) | ||
else: | ||
# lbfgs solver uses different mechanism | ||
return learner(data) | ||
|
||
build_model_func = partial(build_model, self.data, self.learner) | ||
|
||
self.progressBarInit() | ||
|
||
task.future = self._executor.submit(build_model_func) | ||
task.watcher = FutureWatcher(task.future) | ||
task.watcher.done.connect(self._task_finished) | ||
|
||
@Slot(concurrent.futures.Future) | ||
def _task_finished(self, f): | ||
""" | ||
Parameters | ||
---------- | ||
f : Future | ||
The future instance holding the built model | ||
""" | ||
assert self.thread() is QThread.currentThread() | ||
assert self._task is not None | ||
assert self._task.future is f | ||
assert f.done() | ||
|
||
self.setBlocking(False) | ||
|
||
self._task = None | ||
self.progressBarFinished() | ||
|
||
try: | ||
self.model = f.result() | ||
except Exception as ex: # pylint: disable=broad-except | ||
# Log the exception with a traceback | ||
log = logging.getLogger() | ||
log.exception(__name__, exc_info=True) | ||
self.model = None | ||
self.show_fitting_failed(ex) | ||
else: | ||
self.model.name = self.learner_name | ||
self.model.instances = self.data | ||
self.Outputs.model.send(self.model) | ||
|
||
def cancel(self): | ||
""" | ||
Cancel the current task (if any). | ||
""" | ||
if self._task is not None: | ||
self._task.cancel() | ||
assert self._task.future.done() | ||
# disconnect the `_task_finished` slot | ||
self._task.watcher.done.disconnect(self._task_finished) | ||
self._task = None | ||
# threads use signals to run functions in the main thread and some | ||
# can still be quoued (perhaps change) | ||
qApp.processEvents() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this. If needed route the progress updates via an intermediary QObject maybe like this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I haven't thought of this. |
||
self.progressBarFinished() | ||
self.setBlocking(False) | ||
|
||
def onDeleteWidget(self): | ||
self.cancel() | ||
super().onDeleteWidget() | ||
|
||
|
||
if __name__ == "__main__": | ||
a = QApplication(sys.argv) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment (and exception name) are not correct. This does not stop the thread. It cancels/interrupts the task.