diff --git a/Orange/widgets/unsupervised/owcorrespondence.py b/Orange/widgets/unsupervised/owcorrespondence.py index 965f1621e00..509fec28379 100644 --- a/Orange/widgets/unsupervised/owcorrespondence.py +++ b/Orange/widgets/unsupervised/owcorrespondence.py @@ -8,6 +8,7 @@ import pyqtgraph as pg import Orange.data +from Orange.data import Table, Domain, ContinuousVariable, StringVariable from Orange.statistics import contingency from Orange.widgets import widget, gui, settings @@ -15,7 +16,8 @@ from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.visualize.owscatterplotgraph import ScatterPlotItem -from Orange.widgets.widget import Input +from Orange.widgets.widget import Input, Output +from orangewidget.settings import Setting class ScatterPlotItem(pg.ScatterPlotItem): @@ -51,11 +53,15 @@ class OWCorrespondenceAnalysis(widget.OWWidget): class Inputs: data = Input("Data", Orange.data.Table) + class Outputs: + coordinates = Output("Coordinates", Orange.data.Table) + Invalidate = QEvent.registerEventType() settingsHandler = settings.DomainContextHandler() selected_var_indices = settings.ContextSetting([]) + auto_commit = Setting(True) graph_name = "plot.plotItem" @@ -96,6 +102,8 @@ def __init__(self): gui.vBox(self.controlArea, "Contribution to Inertia"), "\n" ) + gui.auto_send(self.controlArea, self, "auto_commit") + gui.rubber(self.controlArea) self.plot = pg.PlotWidget(background="w") @@ -127,6 +135,24 @@ def set_data(self, data): self._restore_selection() self._update_CA() + def commit(self): + output_table = None + if self.ca is not None: + sel_vars = self.selected_vars() + if len(sel_vars) == 2: + rf = np.vstack((self.ca.row_factors, self.ca.col_factors)) + else: + rf = self.ca.row_factors + vars = [(val.name, var) for val in sel_vars for var in val.values] + output_table = Table( + Domain([ContinuousVariable(f"Component {i + 1}") + for i in range(rf.shape[1])], + metas=[StringVariable("Variable"), + StringVariable("Value")]), + rf, metas=vars + ) + self.Outputs.coordinates.send(output_table) + def clear(self): self.data = None self.ca = None @@ -145,8 +171,7 @@ def restore(view, indices): restore(self.varview, self.selected_var_indices) def _p_axes(self): -# return (0, 1) - return (self.component_x, self.component_y) + return self.component_x, self.component_y def _var_changed(self): self.selected_var_indices = sorted( @@ -182,6 +207,7 @@ def _update_CA(self): self._setup_plot() self._update_info() + self.commit() def update_XY(self): self.axis_x_cb.clear() @@ -406,4 +432,4 @@ def inertia_of_axis(self): if __name__ == "__main__": # pragma: no cover - WidgetPreview(OWCorrespondenceAnalysis).run(Orange.data.Table("smokers_ct")) + WidgetPreview(OWCorrespondenceAnalysis).run(Orange.data.Table("titanic")) diff --git a/Orange/widgets/unsupervised/tests/test_owcorrespondence.py b/Orange/widgets/unsupervised/tests/test_owcorrespondence.py index 4a912110fa4..80149d41684 100644 --- a/Orange/widgets/unsupervised/tests/test_owcorrespondence.py +++ b/Orange/widgets/unsupervised/tests/test_owcorrespondence.py @@ -3,12 +3,14 @@ from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable from Orange.widgets.tests.base import WidgetTest from Orange.widgets.unsupervised.owcorrespondence \ - import OWCorrespondenceAnalysis + import OWCorrespondenceAnalysis, select_rows +from Orange.widgets.utils import itemmodels class TestOWCorrespondence(WidgetTest): def setUp(self): self.widget = self.create_widget(OWCorrespondenceAnalysis) + self.data = Table("Titanic") def test_no_data(self): """Check that the widget doesn't crash on empty data""" @@ -73,3 +75,17 @@ def test_no_discrete_variables(self): self.assertTrue(self.widget.Error.no_disc_vars.is_shown()) self.send_signal(self.widget.Inputs.data, Table("iris")) self.assertFalse(self.widget.Error.no_disc_vars.is_shown()) + + def test_outputs(self): + w = self.widget + + self.assertIsNone(self.get_output(w.Outputs.coordinates), None) + self.send_signal(self.widget.Inputs.data, self.data) + self.assertTupleEqual(self.get_output(w.Outputs.coordinates).X.shape, + (6, 2)) + select_rows(w.varview, [0, 1, 2]) + w.commit() + self.assertTupleEqual(self.get_output(w.Outputs.coordinates).X.shape, + (8, 8)) + self.send_signal(self.widget.Inputs.data, None) + self.assertIsNone(self.get_output(w.Outputs.coordinates), None)