From 38aadcecabca05f01cceaddce189a8a301d22ce9 Mon Sep 17 00:00:00 2001 From: janezd Date: Thu, 21 May 2020 14:48:56 +0200 Subject: [PATCH] Select Rows: Fix incorrectly stored values in settings --- Orange/widgets/data/owselectrows.py | 145 ++++++++-------- .../widgets/data/tests/test_owselectrows.py | 155 ++++++++++++++++-- 2 files changed, 223 insertions(+), 77 deletions(-) diff --git a/Orange/widgets/data/owselectrows.py b/Orange/widgets/data/owselectrows.py index 4022330dee5..2b07394d829 100644 --- a/Orange/widgets/data/owselectrows.py +++ b/Orange/widgets/data/owselectrows.py @@ -1,6 +1,5 @@ import enum from collections import OrderedDict -from itertools import chain import numpy as np @@ -13,16 +12,16 @@ QFontMetrics, QPalette ) from AnyQt.QtCore import Qt, QPoint, QRegExp, QPersistentModelIndex, QLocale + +from Orange.widgets.utils.itemmodels import DomainModel from orangewidget.utils.combobox import ComboBoxSearch from Orange.data import ( - Variable, ContinuousVariable, DiscreteVariable, StringVariable, - TimeVariable, + ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable, Table ) import Orange.data.filter as data_filter from Orange.data.filter import FilterContinuous, FilterString -from Orange.data.domain import filter_visible from Orange.data.sql.table import SqlTable from Orange.preprocess import Remove from Orange.widgets import widget, gui @@ -52,24 +51,49 @@ def encode_setting(self, context, setting, value): encoded = [] CONTINUOUS = vartype(ContinuousVariable("x")) for attr, op, values in value: - vtype = context.attributes.get(attr) - if vtype == CONTINUOUS and values and isinstance(values[0], str): - values = [QLocale().toDouble(v)[0] for v in values] - encoded.append((attr, vtype, op, values)) + if isinstance(attr, str): + if OWSelectRows.AllTypes.get(attr) == CONTINUOUS: + values = [QLocale().toDouble(v)[0] for v in values] + # None will match the value returned by all_vars.get + encoded.append((attr, None, op, values)) + else: + if type(attr) is ContinuousVariable \ + and values and isinstance(values[0], str): + values = [QLocale().toDouble(v)[0] for v in values] + elif isinstance(attr, DiscreteVariable): + values = [attr.values[i - 1] if i else "" for i in values] + encoded.append( + (attr.name, context.attributes.get(attr.name), op, values)) return encoded def decode_setting(self, setting, value, domain=None): value = super().decode_setting(setting, value, domain) if setting.name == 'conditions': + CONTINUOUS = vartype(ContinuousVariable("x")) # Use this after 2022/2/2: - # for i, (attr, _, op, values) in enumerate(value): - for i, condition in enumerate(value): - attr = condition[0] - op, values = condition[-2:] - - var = attr in domain and domain[attr] - if var and var.is_continuous and not isinstance(var, TimeVariable): + # for i, (attr, tpe, op, values) in enumerate(value): + # if tpe is not None: + for i, (attr, *tpe, op, values) in enumerate(value): + if tpe != [None] \ + or not tpe and attr not in OWSelectRows.AllTypes: + attr = domain[attr] + if type(attr) is ContinuousVariable \ + or OWSelectRows.AllTypes.get(attr) == CONTINUOUS: values = [QLocale().toString(float(i), 'f') for i in values] + elif isinstance(attr, DiscreteVariable): + # After 2022/2/2, use just the expression in else clause + if values and isinstance(values[0], int): + # Backwards compatibility. Reset setting if we detect + # that the number of values decreased. Still broken if + # they're reordered or we don't detect the decrease. + # + # indices start with 1, thus >, not >= + if max(values) > len(attr.values): + values = (0, ) + else: + values = tuple(attr.to_val(val) + 1 if val else 0 + for val in values if val in attr.values) \ + or (0, ) value[i] = (attr, op, values) return value @@ -80,7 +104,7 @@ def match(self, context, domain, attrs, metas): conditions = context.values["conditions"] all_vars = attrs.copy() all_vars.update(metas) - matched = [all_vars.get(name) == tpe + matched = [all_vars.get(name) == tpe # also matches "all (...)" strings # After 2022/2/2 remove this line: if len(rest) == 2 else name in all_vars for name, tpe, *rest in conditions] @@ -101,6 +125,8 @@ def filter_value(self, setting, data, domain, attrs, metas): # if all_vars.get(name) == tpe] conditions[:] = [ (name, tpe, *rest) for name, tpe, *rest in conditions + # all_vars.get(name) == tpe also matches "all (...)" which are + # encoded with type `None` if (all_vars.get(name) == tpe if len(rest) == 2 else name in all_vars)] @@ -209,6 +235,9 @@ def __init__(self): self.last_output_conditions = None self.data = None self.data_desc = self.match_desc = self.nonmatch_desc = None + self.variable_model = DomainModel( + [list(self.AllTypes), DomainModel.Separator, + DomainModel.CLASSES, DomainModel.ATTRIBUTES, DomainModel.METAS]) box = gui.vBox(self.controlArea, 'Conditions', stretch=100) self.cond_list = QTableWidget( @@ -268,18 +297,11 @@ def add_row(self, attr=None, condition_type=None, condition_value=None): attr_combo = ComboBoxSearch( minimumContentsLength=12, sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon) + attr_combo.setModel(self.variable_model) attr_combo.row = row - for var in self._visible_variables(self.data.domain): - if isinstance(var, Variable): - attr_combo.addItem(*gui.attributeItem(var)) - else: - attr_combo.addItem(var) - if isinstance(attr, str): - attr_combo.setCurrentText(attr) - else: - attr_combo.setCurrentIndex( - attr or - len(self.AllTypes) - (attr_combo.count() == len(self.AllTypes))) + attr_combo.setCurrentIndex(self.variable_model.indexOf(attr) if attr + else len(self.AllTypes) + 1) + self.cond_list.setCellWidget(row, 0, attr_combo) index = QPersistentModelIndex(model.index(row, 3)) @@ -297,15 +319,6 @@ def add_row(self, attr=None, condition_type=None, condition_value=None): self.cond_list.resizeRowToContents(row) - @classmethod - def _visible_variables(cls, domain): - """Generate variables in order they should be presented in in combos.""" - return chain( - cls.AllTypes, - filter_visible(chain(domain.class_vars, - domain.metas, - domain.attributes))) - def add_all(self): if self.cond_list.rowCount(): Mb = QMessageBox @@ -315,9 +328,9 @@ def add_all(self): "filters for all variables.", Mb.Ok | Mb.Cancel) != Mb.Ok: return self.remove_all() - domain = self.data.domain - for i in range(len(domain.variables) + len(domain.metas)): - self.add_row(i) + for attr in self.variable_model[len(self.AllTypes) + 1:]: + self.add_row(attr) + self.conditions_changed() def remove_one(self, rownum): self.remove_one_row(rownum) @@ -333,6 +346,12 @@ def remove_one_row(self, rownum): self.remove_all_button.setDisabled(True) def remove_all_rows(self): + # Disconnect signals to avoid stray emits when changing variable_model + for row in range(self.cond_list.rowCount()): + for col in (0, 1): + widget = self.cond_list.cellWidget(row, col) + if widget: + widget.currentIndexChanged.disconnect() self.cond_list.clear() self.cond_list.setRowCount(0) self.remove_all_button.setDisabled(True) @@ -495,24 +514,18 @@ def set_data(self, data): if not data: self.info.set_input_summary(self.info.NoInput) self.data_desc = None + self.variable_model.set_domain(None) self.commit() return self.data_desc = report.describe_data_brief(data) - self.conditions = [] - try: - self.openContext(data) - except Exception: - pass + self.variable_model.set_domain(data.domain) - variables = list(self._visible_variables(self.data.domain)) - varnames = [v.name if isinstance(v, Variable) else v for v in variables] - if self.conditions: - for attr, cond_type, cond_value in self.conditions: - if attr in varnames: - self.add_row(varnames.index(attr), cond_type, cond_value) - elif attr in self.AllTypes: - self.add_row(attr, cond_type, cond_value) - else: + self.conditions = [] + self.openContext(data) + for attr, cond_type, cond_value in self.conditions: + if attr in self.variable_model: + self.add_row(attr, cond_type, cond_value) + if not self.cond_list.model().rowCount(): self.add_row() self.info.set_input_summary(data.approx_len(), @@ -521,12 +534,15 @@ def set_data(self, data): def conditions_changed(self): try: - self.conditions = [] + cells_by_rows = ( + [self.cond_list.cellWidget(row, col) for col in range(3)] + for row in range(self.cond_list.rowCount()) + ) self.conditions = [ - (self.cond_list.cellWidget(row, 0).currentText(), - self.cond_list.cellWidget(row, 1).currentIndex(), - self._get_value_contents(self.cond_list.cellWidget(row, 2))) - for row in range(self.cond_list.rowCount())] + (var_cell.currentData(gui.TableVariable) or var_cell.currentText(), + oper_cell.currentIndex(), + self._get_value_contents(val_cell)) + for var_cell, oper_cell, val_cell in cells_by_rows] if self.update_on_change and ( self.last_output_conditions is None or self.last_output_conditions != self.conditions): @@ -674,19 +690,18 @@ def send_report(self): pdesc = ndesc conditions = [] - domain = self.data.domain - for attr_name, oper, values in self.conditions: - if attr_name in self.AllTypes: - attr = attr_name + for attr, oper, values in self.conditions: + if isinstance(attr, str): + attr_name = attr + var_type = self.AllTypes[attr] names = self.operator_names[attr_name] - var_type = self.AllTypes[attr_name] else: - attr = domain[attr_name] + attr_name = attr.name var_type = vartype(attr) names = self.operator_names[type(attr)] name = names[oper] if oper == len(names) - 1: - conditions.append("{} {}".format(attr, name)) + conditions.append("{} {}".format(attr_name, name)) elif var_type == 1: # discrete if name == "is one of": valnames = [attr.values[v - 1] for v in values] diff --git a/Orange/widgets/data/tests/test_owselectrows.py b/Orange/widgets/data/tests/test_owselectrows.py index df0d6b4d9dd..a69e3ba9295 100644 --- a/Orange/widgets/data/tests/test_owselectrows.py +++ b/Orange/widgets/data/tests/test_owselectrows.py @@ -1,7 +1,7 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring,unsubscriptable-object import time -from unittest.mock import Mock +from unittest.mock import Mock, patch from AnyQt.QtCore import QLocale, Qt from AnyQt.QtTest import QTest @@ -10,8 +10,10 @@ import numpy as np from Orange.data import ( - Table, ContinuousVariable, StringVariable, DiscreteVariable, Domain) + Table, Variable, ContinuousVariable, StringVariable, DiscreteVariable, + Domain) from Orange.preprocess import discretize +from Orange.widgets.data import owselectrows from Orange.widgets.data.owselectrows import ( OWSelectRows, FilterDiscreteType, SelectRowsContextHandler) from Orange.widgets.tests.base import WidgetTest, datasets @@ -70,7 +72,7 @@ def test_filter_cont(self): for i, (op, _) in enumerate(OWSelectRows.Operators[ContinuousVariable]): self.widget.remove_all() - self.widget.add_row(1, i, CFValues[op]) + self.widget.add_row(iris.domain[0], i, CFValues[op]) self.widget.conditions_changed() self.widget.unconditional_commit() @@ -80,7 +82,7 @@ def test_filter_str(self): self.widget.set_data(zoo) for i, (op, _) in enumerate(OWSelectRows.Operators[StringVariable]): self.widget.remove_all() - self.widget.add_row(1, i, SFValues[op]) + self.widget.add_row(zoo.domain.metas[0], i, SFValues[op]) self.widget.conditions_changed() self.widget.unconditional_commit() @@ -125,6 +127,22 @@ def test_continuous_filter_with_sl_SI_locale(self): self.enterFilter(iris.domain[2], "is below", "5.2") self.assertEqual(self.widget.conditions[0][2], ("52",)) + @override_locale(QLocale.C) # Locale with decimal point + def test_all_numeric_filter_with_c_locale_from_context(self): + iris = Table("iris")[:5] + widget = self.widget_with_context( + iris.domain, [["All numeric variables", None, 0, (3.14, )]]) + self.send_signal(widget.Inputs.data, iris) + self.assertTrue(widget.conditions[0][2][0].startswith("3.14")) + + @override_locale(QLocale.Slovenian) # Locale with decimal comma + def test_all_numeric_filter_with_sl_SI_locale(self): + iris = Table("iris")[:5] + widget = self.widget_with_context( + iris.domain, [["All numeric variables", None, 0, (3.14, )]]) + self.send_signal(widget.Inputs.data, iris) + self.assertTrue(widget.conditions[0][2][0].startswith("3,14")) + @override_locale(QLocale.Slovenian) def test_stores_settings_in_invariant_locale(self): iris = Table("iris")[:5] @@ -140,6 +158,26 @@ def test_stores_settings_in_invariant_locale(self): saved_condition = context.values["conditions"][0] self.assertEqual(saved_condition[3][0], 5.2) + @override_locale(QLocale.C) # Locale with decimal point + def test_store_all_numeric_filter_with_c_locale_to_context(self): + iris = Table("iris")[:5] + self.send_signal(self.widget.Inputs.data, iris) + self.widget.remove_all_button.click() + self.enterFilter("All numeric variables", "equal", "3.14") + context = self.widget.current_context + self.send_signal(self.widget.Inputs.data, None) + self.assertEqual(context.values["conditions"][0][3], [3.14]) + + @override_locale(QLocale.Slovenian) # Locale with decimal comma + def test_store_all_numeric_filter_with_sl_SI_locale_to_context(self): + iris = Table("iris")[:5] + self.send_signal(self.widget.Inputs.data, iris) + self.widget.remove_all_button.click() + self.enterFilter("All numeric variables", "equal", "3,14") + context = self.widget.current_context + self.send_signal(self.widget.Inputs.data, None) + self.assertEqual(context.values["conditions"][0][3], [3.14]) + @override_locale(QLocale.C) def test_restores_continuous_filter_in_c_locale(self): iris = Table("iris")[:5] @@ -183,25 +221,82 @@ def test_partial_matches(self): iris = Table("iris") domain = iris.domain self.widget = self.widget_with_context( - domain, [[domain[0].name, 2, ("5.2",)]]) + domain, [[domain[0].name, 2, 2, ("5.2",)]]) iris2 = iris.transform(Domain(domain.attributes[:2], None)) self.send_signal(self.widget.Inputs.data, iris2) condition = self.widget.conditions[0] - self.assertEqual(condition[0], "sepal length") + self.assertEqual(condition[0], iris.domain[0]) self.assertEqual(condition[1], 2) self.assertTrue(condition[2][0].startswith("5.2")) + def test_partial_match_values(self): + iris = Table("iris") + domain = iris.domain + class_var = domain.class_var + self.widget = self.widget_with_context( + domain, [[class_var.name, 1, 2, + (class_var.values[0], class_var.values[2])]]) + + # sanity checks + self.send_signal(self.widget.Inputs.data, iris) + condition = self.widget.conditions[0] + self.assertIs(condition[0], class_var) + self.assertEqual(condition[1], 2) + self.assertEqual(condition[2], (1, 3)) # indices of values + 1 + + # actual test + new_class_var = DiscreteVariable(class_var.name, class_var.values[1:]) + new_domain = Domain(domain.attributes, new_class_var) + non0 = iris.Y != 0 + iris2 = Table.from_numpy(new_domain, iris.X[non0], iris.Y[non0] - 1) + self.send_signal(self.widget.Inputs.data, iris2) + condition = self.widget.conditions[0] + self.assertIs(condition[0], new_class_var) + self.assertEqual(condition[1], 2) + self.assertEqual(condition[2], (2, )) # index of value + 1 + + def test_backward_compat_match_values(self): + iris = Table("iris") + domain = iris.domain + class_var = domain.class_var + self.widget = self.widget_with_context( + domain, [[class_var.name, 1, 2, (1, 2)]]) + + new_class_var = DiscreteVariable(class_var.name, class_var.values[1:]) + new_domain = Domain(domain.attributes, new_class_var) + non0 = iris.Y != 0 + iris2 = Table.from_numpy(new_domain, iris.X[non0], iris.Y[non0] - 1) + self.send_signal(self.widget.Inputs.data, iris2) + condition = self.widget.conditions[0] + self.assertIs(condition[0], new_class_var) + self.assertEqual(condition[1], 2) + self.assertEqual(condition[2], (1, 2)) # index of value + 1 + + # reset to [0] if out of range + self.widget = self.widget_with_context( + domain, [[class_var.name, 1, 2, (1, 3)]]) + + new_class_var = DiscreteVariable(class_var.name, class_var.values[1:]) + new_domain = Domain(domain.attributes, new_class_var) + non0 = iris.Y != 0 + iris2 = Table.from_numpy(new_domain, iris.X[non0], iris.Y[non0] - 1) + self.send_signal(self.widget.Inputs.data, iris2) + condition = self.widget.conditions[0] + self.assertIs(condition[0], new_class_var) + self.assertEqual(condition[1], 2) + self.assertEqual(condition[2], (0, )) # index of value + 1 + @override_locale(QLocale.C) def test_partial_matches_with_missing_vars(self): iris = Table("iris") domain = iris.domain self.widget = self.widget_with_context( - domain, [[domain[0].name, 2, ("5.2",)], - [domain[2].name, 2, ("4.2",)]]) + domain, [[domain[0].name, 2, 2, ("5.2",)], + [domain[2].name, 2, 2, ("4.2",)]]) iris2 = iris.transform(Domain(domain.attributes[2:], None)) self.send_signal(self.widget.Inputs.data, iris2) condition = self.widget.conditions[0] - self.assertEqual(condition[0], domain[2].name) + self.assertEqual(condition[0], domain[2]) self.assertEqual(condition[1], 2) self.assertTrue(condition[2][0].startswith("4.2")) @@ -335,6 +430,41 @@ def test_keep_operator(self): self.assertEqual( self.widget.cond_list.cellWidget(0, 1).currentText(), "is") + @patch.object(owselectrows.QMessageBox, "question", + return_value=owselectrows.QMessageBox.Ok) + def test_add_all(self, msgbox): + iris = Table("iris") + domain = iris.domain + self.send_signal(self.widget.Inputs.data, iris) + self.widget.add_all_button.click() + msgbox.assert_called() + self.assertEqual([cond[0] for cond in self.widget.conditions], + list(domain.class_vars + domain.attributes)) + + @patch.object(owselectrows.QMessageBox, "question", + return_value=owselectrows.QMessageBox.Cancel) + def test_add_all_cancel(self, msgbox): + iris = Table("iris") + domain = iris.domain + self.send_signal(self.widget.Inputs.data, iris) + self.assertEqual([cond[0] for cond in self.widget.conditions], + list(domain.class_vars)) + self.widget.add_all_button.click() + msgbox.assert_called() + self.assertEqual([cond[0] for cond in self.widget.conditions], + list(domain.class_vars)) + + @patch.object(owselectrows.QMessageBox, "question", + return_value=owselectrows.QMessageBox.Ok) + def test_report(self, _): + zoo = Table("zoo") + self.send_signal(self.widget.Inputs.data, zoo) + self.widget.add_all_button.click() + self.enterFilter("All numeric variables", "equal", "42") + self.enterFilter(zoo.domain[0], "is defined") + self.enterFilter(zoo.domain[1], "is one of") + self.widget.send_report() # don't crash + # Uncomment this on 2022/2/2 # # def test_migration_to_version_1(self): @@ -354,7 +484,7 @@ def test_support_old_settings(self): iris.domain, [["sepal length", 2, ("5.2",)]]) self.send_signal(self.widget.Inputs.data, iris) condition = self.widget.conditions[0] - self.assertEqual(condition[0], "sepal length") + self.assertEqual(condition[0], iris.domain["sepal length"]) self.assertEqual(condition[1], 2) self.assertTrue(condition[2][0].startswith("5.2")) @@ -380,7 +510,7 @@ def test_purge_discretized(self): discretize_class=True, method=method) domain = discretizer(housing) data = housing.transform(domain) - widget = self.widget_with_context(domain, [["MEDV", 2, (2, 3)]]) + widget = self.widget_with_context(domain, [["MEDV", 101, 2, (2, 3)]]) widget.purge_classes = True self.send_signal(widget.Inputs.data, data) out = self.get_output(widget.Outputs.matching_data) @@ -403,7 +533,8 @@ def enterFilter(self, variable, filter, value1=None, value2=None): self.widget.add_button.click() var_combo = self.widget.cond_list.cellWidget(row, 0) - simulate.combobox_activate_item(var_combo, variable.name, delay=0) + name = variable.name if isinstance(variable, Variable) else variable + simulate.combobox_activate_item(var_combo, name, delay=0) oper_combo = self.widget.cond_list.cellWidget(row, 1) simulate.combobox_activate_item(oper_combo, filter, delay=0)