-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Neural network widget that works in a separate thread #2958
Merged
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5bb4d00
owneuralnetwork: execution in a threat and progress bar
markotoplak 677e6f8
Modelling widget tests support asynchronic excution
markotoplak 3fc7292
owneuralnetwork: raise max interations limit
markotoplak e7c89a2
owneuralnetwork: connect callbacks to NN through n_iters_
markotoplak 48bb3bc
owneuralnetwork: Atomic disconnect from task state update notifiers
ales-erjavec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,19 @@ | ||
import sklearn.neural_network as skl_nn | ||
from Orange.base import NNBase | ||
from Orange.regression import SklLearner | ||
from Orange.classification.neural_network import NIterCallbackMixin | ||
|
||
__all__ = ["NNRegressionLearner"] | ||
|
||
|
||
class MLPRegressorWCallback(skl_nn.MLPRegressor, NIterCallbackMixin): | ||
pass | ||
|
||
|
||
class NNRegressionLearner(NNBase, SklLearner): | ||
__wraps__ = skl_nn.MLPRegressor | ||
__wraps__ = MLPRegressorWCallback | ||
|
||
def _initialize_wrapped(self): | ||
clf = SklLearner._initialize_wrapped(self) | ||
clf.orange_callback = getattr(self, "callback", None) | ||
return clf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,48 @@ | ||
from functools import partial | ||
import copy | ||
import logging | ||
import re | ||
import sys | ||
import concurrent.futures | ||
|
||
from AnyQt.QtWidgets import QApplication | ||
from AnyQt.QtCore import Qt | ||
from AnyQt.QtWidgets import QApplication, qApp | ||
from AnyQt.QtCore import Qt, QThread | ||
from AnyQt.QtCore import pyqtSlot as Slot | ||
|
||
from Orange.data import Table | ||
from Orange.modelling import NNLearner | ||
from Orange.widgets import gui | ||
from Orange.widgets.settings import Setting | ||
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner | ||
|
||
from Orange.widgets.utils.concurrent import ( | ||
ThreadExecutor, FutureWatcher, methodinvoke | ||
) | ||
|
||
|
||
class Task: | ||
""" | ||
A class that will hold the state for an learner evaluation. | ||
""" | ||
future = ... # type: concurrent.futures.Future | ||
watcher = ... # type: FutureWatcher | ||
cancelled = False # type: bool | ||
|
||
def cancel(self): | ||
""" | ||
Cancel the task. | ||
|
||
Set the `cancelled` field to True and block until the future is done. | ||
""" | ||
# set cancelled state | ||
self.cancelled = True | ||
self.future.cancel() | ||
concurrent.futures.wait([self.future]) | ||
|
||
|
||
class CancelThreadException(BaseException): | ||
pass | ||
|
||
|
||
class OWNNLearner(OWBaseLearner): | ||
name = "Neural Network" | ||
|
@@ -53,11 +86,20 @@ def add_main_layout(self): | |
label="Alpha:", decimals=5, alignment=Qt.AlignRight, | ||
callback=self.settings_changed, controlWidth=80) | ||
self.max_iter_spin = gui.spin( | ||
box, self, "max_iterations", 10, 300, step=10, | ||
box, self, "max_iterations", 10, 10000, step=10, | ||
label="Max iterations:", orientation=Qt.Horizontal, | ||
alignment=Qt.AlignRight, callback=self.settings_changed, | ||
controlWidth=80) | ||
|
||
def setup_layout(self): | ||
super().setup_layout() | ||
|
||
self._task = None # type: Optional[Task] | ||
self._executor = ThreadExecutor() | ||
|
||
# just a test cancel button | ||
gui.button(self.controlArea, self, "Cancel", callback=self.cancel) | ||
|
||
def create_learner(self): | ||
return self.LEARNER( | ||
hidden_layer_sizes=self.get_hidden_layers(), | ||
|
@@ -81,6 +123,106 @@ def get_hidden_layers(self): | |
self.hidden_layers_edit.setText("100,") | ||
return layers | ||
|
||
def update_model(self): | ||
self.show_fitting_failed(None) | ||
self.model = None | ||
if self.check_data(): | ||
self.__update() | ||
else: | ||
self.Outputs.model.send(self.model) | ||
|
||
@Slot(float) | ||
def setProgressValue(self, value): | ||
assert self.thread() is QThread.currentThread() | ||
self.progressBarSet(value) | ||
|
||
def __update(self): | ||
if self._task is not None: | ||
# First make sure any pending tasks are cancelled. | ||
self.cancel() | ||
assert self._task is None | ||
|
||
self.setBlocking(True) | ||
|
||
self._task = task = Task() | ||
|
||
# A thread safe way to invoke a method | ||
set_progress = methodinvoke(self, "setProgressValue", (float,)) | ||
|
||
max_iter = self.learner.kwargs["max_iter"] | ||
|
||
def callback(iteration): | ||
if task.cancelled: | ||
raise CancelThreadException() # this stop the thread | ||
set_progress(iteration/max_iter*100) | ||
|
||
# copy to set the callback so that the learner output is not modified | ||
# (currently we can not pass callbacks to learners __call__) | ||
learner = copy.copy(self.learner) | ||
learner.callback = callback | ||
|
||
def build_model(data, learner): | ||
return learner(data) | ||
|
||
build_model_func = partial(build_model, self.data, learner) | ||
|
||
self.progressBarInit() | ||
|
||
task.future = self._executor.submit(build_model_func) | ||
task.watcher = FutureWatcher(task.future) | ||
task.watcher.done.connect(self._task_finished) | ||
|
||
@Slot(concurrent.futures.Future) | ||
def _task_finished(self, f): | ||
""" | ||
Parameters | ||
---------- | ||
f : Future | ||
The future instance holding the built model | ||
""" | ||
assert self.thread() is QThread.currentThread() | ||
assert self._task is not None | ||
assert self._task.future is f | ||
assert f.done() | ||
|
||
self.setBlocking(False) | ||
|
||
self._task = None | ||
self.progressBarFinished() | ||
|
||
try: | ||
self.model = f.result() | ||
except Exception as ex: # pylint: disable=broad-except | ||
# Log the exception with a traceback | ||
log = logging.getLogger() | ||
log.exception(__name__, exc_info=True) | ||
self.model = None | ||
self.show_fitting_failed(ex) | ||
else: | ||
self.model.name = self.learner_name | ||
self.model.instances = self.data | ||
self.Outputs.model.send(self.model) | ||
|
||
def cancel(self): | ||
""" | ||
Cancel the current task (if any). | ||
""" | ||
if self._task is not None: | ||
self._task.cancel() | ||
assert self._task.future.done() | ||
# disconnect the `_task_finished` slot | ||
self._task.watcher.done.disconnect(self._task_finished) | ||
self._task = None | ||
# threads use signals to run functions in the main thread and some | ||
# can still be quoued (perhaps change) | ||
qApp.processEvents() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this. If needed route the progress updates via an intermediary QObject maybe like this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I haven't thought of this. |
||
self.progressBarFinished() | ||
self.setBlocking(False) | ||
|
||
def onDeleteWidget(self): | ||
self.cancel() | ||
super().onDeleteWidget() | ||
|
||
|
||
if __name__ == "__main__": | ||
a = QApplication(sys.argv) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment (and exception name) are not correct. This does not stop the thread. It cancels/interrupts the task.