Skip to content

Commit

Permalink
Merge pull request #2152 from jerneju/value-scatterplot
Browse files Browse the repository at this point in the history
[FIX] Scatter Plot: dealing with scipy sparse matrix
  • Loading branch information
nikicc authored Apr 21, 2017
2 parents 8e13aeb + cceeee3 commit 9274a20
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 15 deletions.
37 changes: 36 additions & 1 deletion Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import scipy.sparse as sp

from AnyQt.QtCore import Qt, QTimer
from AnyQt.QtGui import (
Expand Down Expand Up @@ -394,7 +395,8 @@ def set_subset_data(self, subset_data):

# called when all signals are received, so the graph is updated only once
def handleNewSignals(self):
self.graph.new_data(self.data_metas_X, self.subset_data)
self.graph.new_data(self.sparse_to_dense(self.data_metas_X),
self.sparse_to_dense(self.subset_data))
if self.attribute_selection_list and \
all(attr in self.graph.domain
for attr in self.attribute_selection_list):
Expand All @@ -407,6 +409,37 @@ def handleNewSignals(self):
self.apply_selection()
self.unconditional_commit()

def prepare_data(self):
"""
Only when dealing with sparse matrices.
GH-2152
"""
self.graph.new_data(self.sparse_to_dense(self.data_metas_X),
self.sparse_to_dense(self.subset_data),
new=False)

def sparse_to_dense(self, input_data=None):
self.vizrank_button.setEnabled(not (self.data and self.data.is_sparse()))
if input_data is None or not input_data.is_sparse():
return input_data
keys = []
attrs = {self.attr_x,
self.attr_y,
self.graph.attr_color,
self.graph.attr_shape,
self.graph.attr_size,
self.graph.attr_label}
for i, attr in enumerate(input_data.domain):
if attr in attrs:
keys.append(i)
new_domain = input_data.domain.select_columns(keys)
dmx = Table.from_table(new_domain, input_data)
dmx.X = dmx.X.toarray()
# TODO: remove once we make sure Y is always dense.
if sp.issparse(dmx.Y):
dmx.Y = dmx.Y.toarray()
return dmx

def apply_selection(self):
"""Apply selection saved in workflow."""
if self.data is not None and self.selection is not None:
Expand Down Expand Up @@ -441,12 +474,14 @@ def set_attr(self, attr_x, attr_y):
self.update_attr()

def update_attr(self):
self.prepare_data()
self.update_graph()
self.cb_class_density.setEnabled(self.graph.can_draw_density())
self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
self.send_features()

def update_colors(self):
self.prepare_data()
self.cb_class_density.setEnabled(self.graph.can_draw_density())

def update_density(self):
Expand Down
38 changes: 24 additions & 14 deletions Orange/widgets/visualize/owscatterplotgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,17 +591,18 @@ def update_tooltip(self, modifiers):
text = self.tiptexts.get(int(modifiers), self.tiptexts[0])
self.tip_textitem.setHtml(text)

def new_data(self, data, subset_data=None, **args):
self.plot_widget.clear()
self.remove_legend()
def new_data(self, data, subset_data=None, new=True, **args):
if new:
self.plot_widget.clear()
self.remove_legend()

self.density_img = None
self.scatterplot_item = None
self.scatterplot_item_sel = None
self.reg_line_item = None
self.labels = []
self.selection = None
self.valid_data = None
self.density_img = None
self.scatterplot_item = None
self.scatterplot_item_sel = None
self.reg_line_item = None
self.labels = []
self.selection = None
self.valid_data = None

self.subset_indices = set(e.id for e in subset_data) if subset_data else None

Expand Down Expand Up @@ -776,13 +777,15 @@ def compute_sizes(self):
return size_data

def update_sizes(self):
self.master.prepare_data()
self.update_point_size()

def update_point_size(self):
if self.scatterplot_item:
size_data = self.compute_sizes()
self.scatterplot_item.setSize(size_data)
self.scatterplot_item_sel.setSize(size_data + SELECTION_WIDTH)

update_point_size = update_sizes

def get_color_index(self):
if self.attr_color is None:
return -1
Expand Down Expand Up @@ -907,6 +910,9 @@ def make_pen(color, width):

def update_colors(self, keep_colors=False):
self.master.update_colors()
self.update_alpha_value(keep_colors)

def update_alpha_value(self, keep_colors=False):
if self.scatterplot_item:
pen_data, brush_data = self.compute_colors(keep_colors)
pen_data_sel, brush_data_sel = self.compute_colors_sel(keep_colors)
Expand All @@ -922,8 +928,6 @@ def update_colors(self, keep_colors=False):
elif self.density_img:
self.plot_widget.removeItem(self.density_img)

update_alpha_value = update_colors

def create_labels(self):
for x, y in zip(*self.scatterplot_item.getData()):
ti = TextItem()
Expand All @@ -937,6 +941,7 @@ def update_labels(self):
for label in self.labels:
label.setText("")
return
self.assure_attribute_present(self.attr_label)
if not self.labels:
self.create_labels()
label_column = self.data.get_column_view(self.attr_label)[0]
Expand Down Expand Up @@ -972,11 +977,16 @@ def compute_symbols(self):
return shape_data

def update_shapes(self):
self.assure_attribute_present(self.attr_shape)
if self.scatterplot_item:
shape_data = self.compute_symbols()
self.scatterplot_item.setSymbol(shape_data)
self.make_legend()

def assure_attribute_present(self, attr):
if attr not in self.data.domain:
self.master.prepare_data()

def update_grid(self):
self.plot_widget.showGrid(x=self.show_grid, y=self.show_grid)

Expand Down
18 changes: 18 additions & 0 deletions Orange/widgets/visualize/tests/test_owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=missing-docstring
from unittest.mock import MagicMock
import numpy as np
import scipy.sparse as sp

from AnyQt.QtCore import QRectF, Qt

Expand Down Expand Up @@ -268,6 +269,23 @@ def test_set_strings_settings(self):
self.assertEqual(w.graph.attr_shape.name, "iris")
self.assertEqual(w.graph.attr_size.name, "petal width")

def test_sparse(self):
"""
Test sparse data.
GH-2152
GH-2157
"""
table = Table("iris")
table.X = sp.csr_matrix(table.X)
self.assertTrue(sp.issparse(table.X))
table.Y = sp.csr_matrix(table._Y) # pylint: disable=protected-access
self.assertTrue(sp.issparse(table.Y))
self.send_signal("Data", table)
self.widget.set_subset_data(table[:30])
data = self.get_output("Data")
self.assertTrue(data.is_sparse())
self.assertEqual(len(data.domain), 5)


if __name__ == "__main__":
import unittest
Expand Down

0 comments on commit 9274a20

Please sign in to comment.