Skip to content

Commit

Permalink
VizRankDialog: Use extended thread pool to prevent segfaults
Browse files Browse the repository at this point in the history
When fed large datasets, the Correlations widget exited with a segmentation fault, probably due to insufficient stack size for the created task.
  • Loading branch information
VesnaT committed Mar 15, 2019
1 parent b6a88e3 commit 5e14813
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 89 deletions.
24 changes: 14 additions & 10 deletions Orange/widgets/data/owcorrelations.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class KMeansCorrelationHeuristic:
def __init__(self, data):
self.n_attributes = len(data.domain.attributes)
self.data = data
self.states = None
self.clusters = None
self.n_clusters = int(np.sqrt(self.n_attributes))

def get_clusters_of_attributes(self):
Expand All @@ -84,16 +84,15 @@ def get_states(self, initial_state):
:param initial_state: initial state; None if this is the first call
:return: generator of tuples of states
"""
if self.states is not None:
return chain([initial_state], self.states)

clusters = self.get_clusters_of_attributes()
if self.clusters is None:
self.clusters = self.get_clusters_of_attributes()
clusters = self.clusters

# combinations within clusters
self.states = chain.from_iterable(combinations(cluster.instances, 2)
for cluster in clusters)
states0 = chain.from_iterable(combinations(cluster.instances, 2)
for cluster in clusters)
if self.n_clusters == 1:
return self.states
return states0

# combinations among clusters - closest clusters first
centroids = [c.centroid for c in clusters]
Expand All @@ -104,8 +103,13 @@ def get_states(self, initial_state):
states = ((min((c1, c2)), max((c1, c2))) for i in np.argsort(distances)
for c1 in clusters[cluster_combs[i][0]].instances
for c2 in clusters[cluster_combs[i][1]].instances)
self.states = chain(self.states, states)
return self.states
states = chain(states0, states)

if initial_state is not None:
while next(states) != initial_state:
pass
return chain([initial_state], states)
return states


class CorrelationRank(VizRankDialogAttrPair):
Expand Down
12 changes: 12 additions & 0 deletions Orange/widgets/data/tests/test_owcorrelations.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,14 @@ def test_row_for_state(self):
self.assertEqual(row[1].data(Qt.DisplayRole), self.attrs[0].name)
self.assertEqual(row[2].data(Qt.DisplayRole), self.attrs[1].name)

def test_iterate_states(self):
    """Check that ``iterate_states`` yields all pairs, or resumes mid-way.

    With ``None`` the full sequence is expected; with an initial state the
    sequence is expected to start at that state (inclusive).
    """
    all_states = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
    self.assertListEqual(list(self.vizrank.iterate_states(None)),
                         all_states)
    # Resuming from the very first state yields the same sequence.
    self.assertListEqual(list(self.vizrank.iterate_states((1, 0))),
                         all_states)
    # Resuming from (2, 1) skips the two states that precede it.
    self.assertListEqual(list(self.vizrank.iterate_states((2, 1))),
                         all_states[2:])

def test_iterate_states_by_feature(self):
self.vizrank.sel_feature_index = 2
states = self.vizrank.iterate_states_by_feature()
Expand Down Expand Up @@ -345,3 +353,7 @@ def test_get_states_one_cluster(self):
states = set(heuristic.get_states(None))
self.assertEqual(len(states), 1)
self.assertSetEqual(states, {(0, 1)})


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
17 changes: 10 additions & 7 deletions Orange/widgets/visualize/tests/test_owlinearprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from Orange.widgets.visualize.owlinearprojection import (
OWLinearProjection, LinearProjectionVizRank
)
from Orange.widgets.visualize.utils import Worker
from Orange.widgets.visualize.utils import run_vizrank


class TestOWLinearProjection(WidgetTest, AnchorProjectionWidgetTestMixin,
Expand Down Expand Up @@ -205,16 +205,14 @@ def setUp(self):

def test_discrete_class(self):
self.send_signal(self.widget.Inputs.data, self.data)
worker = Worker(self.vizrank)
self.vizrank.keep_running = True
worker.do_work()
run_vizrank(self.vizrank.compute_score,
self.vizrank.iterate_states(None), [], Mock())

def test_continuous_class(self):
data = Table("housing")[::100]
self.send_signal(self.widget.Inputs.data, data)
worker = Worker(self.vizrank)
self.vizrank.keep_running = True
worker.do_work()
run_vizrank(self.vizrank.compute_score,
self.vizrank.iterate_states(None), [], Mock())

def test_set_attrs(self):
self.send_signal(self.widget.Inputs.data, self.data)
Expand All @@ -230,3 +228,8 @@ def test_set_attrs(self):
self.assertNotEqual(self.widget.model_selected[:], model_selected)
c2 = self.get_output(self.widget.Outputs.components)
self.assertNotEqual(c1.domain.attributes, c2.domain.attributes)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    import unittest
    unittest.main()
132 changes: 132 additions & 0 deletions Orange/widgets/visualize/tests/test_vizrankdialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from itertools import chain
import unittest
from unittest.mock import Mock
from queue import Queue

from AnyQt.QtGui import QStandardItem

from Orange.data import Table
from Orange.widgets.visualize.utils import (
VizRankDialog, Result, run_vizrank, QueuedScore
)
from Orange.widgets.tests.base import WidgetTest


def compute_score(x):
    """Return the toy score ``(x[0] + 1) / (x[1] + 1)`` for a state ``x``.

    Only positions 0 and 1 of ``x`` are read. Adding one to both keeps the
    denominator non-zero for the non-negative index pairs used in the tests.
    """
    return (x[0] + 1) / (x[1] + 1)


class TestRunner(unittest.TestCase):
    """Tests for ``Result`` and ``run_vizrank`` driven by a mocked task."""

    @classmethod
    def setUpClass(cls):
        cls.data = Table("iris")

    def test_Result(self):
        """``Result`` exposes the queue and the scores it was built with."""
        res = Result(queue=Queue(), scores=[])
        self.assertIsInstance(res.queue, Queue)
        self.assertIsInstance(res.scores, list)

    def test_run_vizrank(self):
        """Run through every state without any interruption."""
        scores, task = [], Mock()
        # run through all states
        task.is_interruption_requested.return_value = False
        states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
        res = run_vizrank(compute_score, chain(states), scores, task)

        # NOTE(review): the positions look like insertion indices into the
        # sorted score list -- confirm against run_vizrank.
        next_state = self.assertQueueEqual(
            res.queue, [0, 0, 0, 3, 2, 5], compute_score,
            states, states[1:] + [None])
        self.assertIsNone(next_state)
        res_scores = sorted([compute_score(x) for x in states])
        self.assertListEqual(res.scores, res_scores)
        self.assertIsNot(scores, res.scores)
        self.assertEqual(task.set_partial_result.call_count, 6)

    def test_run_vizrank_interrupt(self):
        """Interrupt after two states, then resume from the saved state."""
        scores, task = [], Mock()
        # interrupt calculation in third iteration
        task.is_interruption_requested.side_effect = \
            lambda: task.is_interruption_requested.call_count > 2
        states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
        res = run_vizrank(compute_score, chain(states), scores, task)

        next_state = self.assertQueueEqual(
            res.queue, [0, 0], compute_score, states[:2], states[1:3])
        self.assertEqual(next_state, (0, 3))
        res_scores = sorted([compute_score(x) for x in states[:2]])
        self.assertListEqual(res.scores, res_scores)
        self.assertIsNot(scores, res.scores)
        self.assertEqual(task.set_partial_result.call_count, 2)

        # continue calculation through all states
        task.is_interruption_requested.side_effect = lambda: False
        i = states.index(next_state)
        res = run_vizrank(compute_score, chain(states[i:]), res_scores, task)

        next_state = self.assertQueueEqual(
            res.queue, [0, 3, 2, 5], compute_score, states[2:],
            states[3:] + [None])
        self.assertIsNone(next_state)
        res_scores = sorted([compute_score(x) for x in states])
        self.assertListEqual(res.scores, res_scores)
        self.assertIsNot(scores, res.scores)
        # The mock keeps counting across both runs: 2 + 4 partial results.
        self.assertEqual(task.set_partial_result.call_count, 6)

    def assertQueueEqual(self, queue, positions, f, states, next_states):
        """Pop items off ``queue`` and compare them with the expected ones.

        One ``QueuedScore`` is built per ``(position, state, next_state)``
        triple, with ``f(state)`` as its score, and compared field by field
        with the next item taken from ``queue`` via ``get_nowait``.

        :param queue: queue holding the produced results
        :param positions: expected ``position`` of each queued item
        :param f: scoring function applied to each expected state
        :param states: expected ``state`` of each queued item
        :param next_states: expected ``next_state`` of each queued item
        :return: ``next_state`` of the last compared item, or ``None`` when
            there was nothing to compare
        """
        self.assertIsInstance(queue, Queue)
        # Start from None so an empty expectation does not leave the name
        # unbound at the return statement.
        next_state = None
        for qs in (QueuedScore(position=p, score=f(s), state=s, next_state=ns)
                   for p, s, ns in zip(positions, states, next_states)):
            result = queue.get_nowait()
            self.assertEqual(result.position, qs.position)
            self.assertEqual(result.state, qs.state)
            self.assertEqual(result.next_state, qs.next_state)
            self.assertEqual(result.score, qs.score)
            next_state = result.next_state
        return next_state


class TestVizRankDialog(WidgetTest):
    """Tests for how ``VizRankDialog`` consumes partial results."""

    def test_on_partial_result(self):
        """Feed an interrupted run and a resumed run to the dialog.

        After each call the rank model rows and ``saved_progress`` are
        compared with the scores of the states processed so far.
        """
        def iterate_states(initial_state):
            # Resume from ``initial_state`` (inclusive) when one is given.
            if initial_state is not None:
                return chain(states[states.index(initial_state):])
            return chain(states)

        def invoke_on_partial_result():
            # Run the computation synchronously and hand its result
            # straight to the dialog.
            widget.on_partial_result(run_vizrank(
                widget.compute_score,
                widget.iterate_states(widget.saved_state),
                widget.scores, task
            ))

        task = Mock()
        states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

        widget = VizRankDialog(None)
        widget.progressBarInit()
        # Replace the dialog's hooks with simple stand-ins for this test.
        widget.compute_score = compute_score
        widget.iterate_states = iterate_states
        widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))]

        # interrupt calculation in third iteration
        task.is_interruption_requested.side_effect = \
            lambda: task.is_interruption_requested.call_count > 2
        invoke_on_partial_result()
        self.assertEqual(widget.rank_model.rowCount(), 2)
        for row, score in enumerate(
                sorted([compute_score(x) for x in states[:2]])):
            self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
        self.assertEqual(widget.saved_progress, 2)

        # continue calculation through all states
        task.is_interruption_requested.side_effect = lambda: False
        invoke_on_partial_result()
        self.assertEqual(widget.rank_model.rowCount(), 6)
        for row, score in enumerate(
                sorted([compute_score(x) for x in states])):
            self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
        self.assertEqual(widget.saved_progress, 6)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading

0 comments on commit 5e14813

Please sign in to comment.