diff --git a/Orange/statistics/util.py b/Orange/statistics/util.py index 75e3104014d..0a28b5a806e 100644 --- a/Orange/statistics/util.py +++ b/Orange/statistics/util.py @@ -290,7 +290,7 @@ def nanmean(x): return np.nansum(x.data) / n_values -def unique(x, return_counts=True): +def unique(x, return_counts=False): """ Equivalent of np.unique that supports sparse or dense matrices. """ if not sp.issparse(x): return np.unique(x, return_counts=return_counts) diff --git a/Orange/widgets/data/owfile.py b/Orange/widgets/data/owfile.py index db477b8ab99..5bc01ab2c64 100644 --- a/Orange/widgets/data/owfile.py +++ b/Orange/widgets/data/owfile.py @@ -403,11 +403,6 @@ def apply_domain_edit(self): if self.data is not None: domain, cols = self.domain_editor.get_domain(self.data.domain, self.data) X, y, m = cols - X = np.array(X).T if len(X) else np.empty((len(self.data), 0)) - y = np.array(y).T if len(y) else None - dtpe = object if any(isinstance(m, StringVariable) - for m in domain.metas) else float - m = np.array(m, dtype=dtpe).T if len(m) else None table = Table.from_numpy(domain, X, y, m, self.data.W) table.name = self.data.name table.ids = np.array(self.data.ids) diff --git a/Orange/widgets/data/tests/test_owfile.py b/Orange/widgets/data/tests/test_owfile.py index d655cf11cd6..32affc1bf9e 100644 --- a/Orange/widgets/data/tests/test_owfile.py +++ b/Orange/widgets/data/tests/test_owfile.py @@ -2,8 +2,12 @@ # pylint: disable=missing-docstring from os import path, remove from unittest.mock import Mock +import pickle +import tempfile + import numpy as np +import scipy.sparse as sp from AnyQt.QtCore import QMimeData, QPoint, Qt, QUrl from AnyQt.QtGui import QDragEnterEvent, QDropEvent @@ -195,3 +199,19 @@ def test_check_datetime_disabled(self): for i in range(4): vartype_delegate.setEditorData(combo, idx(i)) self.assertEqual(combo.count(), counts[i]) + + def test_domain_edit_on_sparse_data(self): + iris = Table("iris") + iris.X = sp.csr_matrix(iris.X) + + f = tempfile.NamedTemporaryFile(suffix='.pickle', delete=False) + pickle.dump(iris, f) + f.close() + + self.widget.add_path(f.name) + self.widget.load_data() + + output = self.get_output("Data") + self.assertIsInstance(output, Table) + self.assertEqual(iris.X.shape, output.X.shape) + self.assertTrue(sp.issparse(output.X)) diff --git a/Orange/widgets/utils/domaineditor.py b/Orange/widgets/utils/domaineditor.py index 02ab8c7a005..387e7fe8722 100644 --- a/Orange/widgets/utils/domaineditor.py +++ b/Orange/widgets/utils/domaineditor.py @@ -1,6 +1,7 @@ from itertools import chain import numpy as np +import scipy.sparse as sp from AnyQt.QtCore import Qt, QAbstractTableModel from AnyQt.QtGui import QColor @@ -8,6 +9,7 @@ from Orange.data import DiscreteVariable, ContinuousVariable, StringVariable, \ TimeVariable, Domain +from Orange.statistics.util import unique from Orange.widgets import gui from Orange.widgets.gui import HorizontalGridDelegate from Orange.widgets.settings import ContextSetting @@ -196,6 +198,37 @@ def __init__(self, widget): self.place_delegate = PlaceDelegate(self, VarTableModel.places) self.setItemDelegateForColumn(Column.place, self.place_delegate) + @staticmethod + def _is_missing(x): + return str(x) in ("nan", "") + + @staticmethod + def _iter_vals(x): + """Iterate over values of sparse or dense arrays.""" + for i in range(x.shape[0]): + yield x[i, 0] + + @staticmethod + def _to_column(x, to_sparse, dtype=None): + """Transform list of values to sparse/dense column array.""" + x = np.array(x, dtype=dtype).reshape(-1, 1) + if to_sparse: + x = sp.csc_matrix(x) + return x + + @staticmethod + def _merge(cols, force_dense=False): + if len(cols) == 0: + return None + + all_dense = not any(sp.issparse(c) for c in cols) + if all_dense: + return np.hstack(cols) + if force_dense: + return np.hstack([c.toarray() if sp.issparse(c) else c for c in cols]) + sparse_cols = [c if sp.issparse(c) else sp.csc_matrix(c) for c in cols] + return sp.hstack(sparse_cols).tocsr() + def get_domain(self, domain, data): """Create domain (and dataset) from changes made in the widget. @@ -212,9 +245,6 @@ def get_domain(self, domain, data): places = [[], [], []] # attributes, class_vars, metas cols = [[], [], []] # Xcols, Ycols, Mcols - def is_missing(x): - return str(x) in ("nan", "") - for (name, tpe, place, _, _), (orig_var, orig_plc) in \ zip(variables, chain([(at, Place.feature) for at in domain.attributes], @@ -222,13 +252,9 @@ def is_missing(x): [(mt, Place.meta) for mt in domain.metas])): if place == Place.skip: continue - if orig_plc == Place.meta: - col_data = data[:, orig_var].metas - elif orig_plc == Place.class_var: - col_data = data[:, orig_var].Y - else: - col_data = data[:, orig_var].X - col_data = col_data.ravel() + + col_data = self._get_column(data, orig_var, orig_plc) + is_sparse = sp.issparse(col_data) if name == orig_var.name and tpe == type(orig_var): var = orig_var elif tpe == type(orig_var): @@ -236,20 +262,40 @@ def is_missing(x): orig_var.name = name var = orig_var elif tpe == DiscreteVariable: - values = list(str(i) for i in np.unique(col_data) if not is_missing(i)) + values = list(str(i) for i in unique(col_data) if not self._is_missing(i)) var = tpe(name, values) - col_data = [np.nan if is_missing(x) else values.index(str(x)) - for x in col_data] + col_data = [np.nan if self._is_missing(x) else values.index(str(x)) + for x in self._iter_vals(col_data)] + col_data = self._to_column(col_data, is_sparse) elif tpe == StringVariable and type(orig_var) == DiscreteVariable: var = tpe(name) col_data = [orig_var.repr_val(x) if not np.isnan(x) else "" - for x in col_data] + for x in self._iter_vals(col_data)] + # don't obey sparsity for StringVariable since they are + # in metas which are transformed to dense below + col_data = self._to_column(col_data, False, dtype=object) else: var = tpe(name) places[place].append(var) cols[place].append(col_data) + + # merge columns for X, Y and metas + feats = cols[Place.feature] + X = self._merge(feats) if len(feats) else np.empty((len(data), 0)) + Y = self._merge(cols[Place.class_var], force_dense=True) + m = self._merge(cols[Place.meta], force_dense=True) domain = Domain(*places) - return domain, cols + return domain, [X, Y, m] + + def _get_column(self, data, source_var, source_place): + """ Extract column from data and preserve sparsity. """ + if source_place == Place.meta: + col_data = data[:, source_var].metas + elif source_place == Place.class_var: + col_data = data[:, source_var].Y.reshape(-1, 1) + else: + col_data = data[:, source_var].X + return col_data def set_domain(self, domain): self.variables = self.parse_domain(domain)