From 136dafc96b56ead4a19b224d28586de731587448 Mon Sep 17 00:00:00 2001 From: janezd Date: Fri, 14 Jan 2022 22:55:32 +0100 Subject: [PATCH] Predictions: Allow choosing a target --- Orange/widgets/evaluate/owpredictions.py | 76 +++++++++++++++++++----- 1 file changed, 62 insertions(+), 14 deletions(-) diff --git a/Orange/widgets/evaluate/owpredictions.py b/Orange/widgets/evaluate/owpredictions.py index 470c92b80ad..dc1d3206cc8 100644 --- a/Orange/widgets/evaluate/owpredictions.py +++ b/Orange/widgets/evaluate/owpredictions.py @@ -76,6 +76,9 @@ class Error(OWWidget.Error): #: List of selected class value indices in the `class_values` list selected_classes = settings.ContextSetting([]) selection = settings.Setting([], schema_only=True) + show_scores = settings.Setting(True) + TARGET_AVERAGE = "(Average over classes)" + target_class = settings.ContextSetting(TARGET_AVERAGE) def __init__(self): super().__init__() @@ -84,22 +87,36 @@ def __init__(self): self.predictors = [] # type: List[PredictorSlot] self.class_values = [] # type: List[str] self._delegates = [] + self.scorer_errors = [] self.left_width = 10 self.selection_store = None self.__pending_selection = self.selection - controlBox = gui.vBox(self.controlArea, "Show probabilities for") - - gui.listBox(controlBox, self, "selected_classes", "class_values", - callback=self._update_prediction_delegate, - selectionMode=QListWidget.ExtendedSelection, - sizePolicy=(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding), - sizeHint=QSize(1, 350), - minimumHeight=100) self.reset_button = gui.button( - controlBox, self, "Restore Original Order", + self.controlArea, self, "Restore Original Order", callback=self._reset_order, tooltip="Show rows in the original order") + gui.separator(self.controlArea, 16) + + gui.listBox( + self.controlArea, self, "selected_classes", "class_values", + box="Show probabilities", + callback=self._update_prediction_delegate, + selectionMode=QListWidget.ExtendedSelection, + sizePolicy=(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding), + minimumHeight=100, maximumHeight=150) + + gui.rubber(self.controlArea) + + box = gui.vBox(self.controlArea, "Model Performance") + gui.checkBox( + box, self, "show_scores", "Show perfomance scores", + callback=self._update_score_table_visibility + ) + self.target_selection = gui.comboBox( + box, self, "target_class", items=[], label="Target class:", + sendSelectedValue=True, callback=self._on_target_changed + ) table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn, horizontalScrollMode=QTableView.ScrollPerPixel, @@ -209,15 +226,22 @@ def _set_class_values(self): if value not in class_values: class_values.append(value) + self.target_selection.clear() + self.target_selection.addItem(self.TARGET_AVERAGE) if self.class_var and self.class_var.is_discrete: values = self.class_var.values + self.target_selection.addItems(values) + self.target_selection.box.setVisible(True) self.class_values = sorted( class_values, key=lambda val: val not in values) self.selected_classes = [ i for i, name in enumerate(class_values) if name in values] + self.controls.selected_classes.box.setVisible(True) else: self.class_values = class_values # This assignment updates listview self.selected_classes = [] + self.controls.selected_classes.box.setVisible(False) + self.target_selection.box.setVisible(False) def handleNewSignals(self): # Disconnect the model: the model and the delegate will be inconsistent @@ -231,6 +255,16 @@ def handleNewSignals(self): self._set_errors() self.commit() + def _on_target_changed(self): + self._update_scores() + # The widget doesn't have conditional commits, so we can call this + # If it would have the in the future, I'd still call the same function + # because one commit button has only one dirty flag, and unnecessarily + # making an unconditional commit af a small table is better than + # conditionally commit all predictions. + self._commit_evaluation_results() + + def _call_predictors(self): if not self.data: return @@ -281,10 +315,15 @@ def _call_predictors(self): def _update_scores(self): model = self.score_table.model + if self.class_var and self.class_var.is_discrete \ + and self.target_class != self.TARGET_AVERAGE: + target = self.class_var.values.index(self.target_class) + else: + target = None model.clear() scorers = usable_scorers(self.class_var) if self.class_var else [] self.score_table.update_header(scorers) - errors = [] + self.scorer_errors = errors = [] for pred in self.predictors: results = pred.results if not isinstance(results, Results) or results.predicted is None: @@ -294,7 +333,7 @@ def _update_scores(self): for scorer in scorers: item = QStandardItem() try: - score = scorer_caller(scorer, results)()[0] + score = scorer_caller(scorer, results, target=target)()[0] item.setText(f"{score:.3f}") except Exception as exc: # pylint: disable=broad-except item.setToolTip(str(exc)) @@ -304,17 +343,26 @@ def _update_scores(self): row.append(item) self.score_table.model.appendRow(row) + self._update_score_table_visibility() + + def _update_score_table_visibility(self): view = self.score_table.view - if model.rowCount(): + nmodels = self.score_table.model.rowCount() + if nmodels and self.show_scores: view.setVisible(True) view.ensurePolished() + view.resizeColumnsToContents() + view.resizeRowsToContents() view.setFixedHeight( 5 + view.horizontalHeader().height() + - view.verticalHeader().sectionSize(0) * model.rowCount()) + view.verticalHeader().sectionSize(0) * nmodels) + + errors = "\n".join(self.scorer_errors) + self.Error.scorer_failed(errors, shown=bool(errors)) else: view.setVisible(False) + self.Error.scorer_failed.clear() - self.Error.scorer_failed("\n".join(errors), shown=bool(errors)) def _set_errors(self): # Not all predictors are run every time, so errors can't be collected