Skip to content

Commit

Permalink
Predictions: Allow choosing a target
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jan 14, 2022
1 parent c860359 commit 136dafc
Showing 1 changed file with 62 additions and 14 deletions.
76 changes: 62 additions & 14 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class Error(OWWidget.Error):
#: List of selected class value indices in the `class_values` list
selected_classes = settings.ContextSetting([])
selection = settings.Setting([], schema_only=True)
show_scores = settings.Setting(True)
TARGET_AVERAGE = "(Average over classes)"
target_class = settings.ContextSetting(TARGET_AVERAGE)

def __init__(self):
super().__init__()
Expand All @@ -84,22 +87,36 @@ def __init__(self):
self.predictors = [] # type: List[PredictorSlot]
self.class_values = [] # type: List[str]
self._delegates = []
self.scorer_errors = []
self.left_width = 10
self.selection_store = None
self.__pending_selection = self.selection

controlBox = gui.vBox(self.controlArea, "Show probabilities for")

gui.listBox(controlBox, self, "selected_classes", "class_values",
callback=self._update_prediction_delegate,
selectionMode=QListWidget.ExtendedSelection,
sizePolicy=(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding),
sizeHint=QSize(1, 350),
minimumHeight=100)
self.reset_button = gui.button(
controlBox, self, "Restore Original Order",
self.controlArea, self, "Restore Original Order",
callback=self._reset_order,
tooltip="Show rows in the original order")
gui.separator(self.controlArea, 16)

gui.listBox(
self.controlArea, self, "selected_classes", "class_values",
box="Show probabilities",
callback=self._update_prediction_delegate,
selectionMode=QListWidget.ExtendedSelection,
sizePolicy=(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding),
minimumHeight=100, maximumHeight=150)

gui.rubber(self.controlArea)

box = gui.vBox(self.controlArea, "Model Performance")
gui.checkBox(
box, self, "show_scores", "Show perfomance scores",
callback=self._update_score_table_visibility
)
self.target_selection = gui.comboBox(
box, self, "target_class", items=[], label="Target class:",
sendSelectedValue=True, callback=self._on_target_changed
)

table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
horizontalScrollMode=QTableView.ScrollPerPixel,
Expand Down Expand Up @@ -209,15 +226,22 @@ def _set_class_values(self):
if value not in class_values:
class_values.append(value)

self.target_selection.clear()
self.target_selection.addItem(self.TARGET_AVERAGE)
if self.class_var and self.class_var.is_discrete:
values = self.class_var.values
self.target_selection.addItems(values)
self.target_selection.box.setVisible(True)
self.class_values = sorted(
class_values, key=lambda val: val not in values)
self.selected_classes = [
i for i, name in enumerate(class_values) if name in values]
self.controls.selected_classes.box.setVisible(True)
else:
self.class_values = class_values # This assignment updates listview
self.selected_classes = []
self.controls.selected_classes.box.setVisible(False)
self.target_selection.box.setVisible(False)

def handleNewSignals(self):
# Disconnect the model: the model and the delegate will be inconsistent
Expand All @@ -231,6 +255,16 @@ def handleNewSignals(self):
self._set_errors()
self.commit()

def _on_target_changed(self):
self._update_scores()
# The widget doesn't have conditional commits, so we can call this
# If it would have the in the future, I'd still call the same function
# because one commit button has only one dirty flag, and unnecessarily
# making an unconditional commit af a small table is better than
# conditionally commit all predictions.
self._commit_evaluation_results()


def _call_predictors(self):
if not self.data:
return
Expand Down Expand Up @@ -281,10 +315,15 @@ def _call_predictors(self):

def _update_scores(self):
model = self.score_table.model
if self.class_var and self.class_var.is_discrete \
and self.target_class != self.TARGET_AVERAGE:
target = self.class_var.values.index(self.target_class)
else:
target = None
model.clear()
scorers = usable_scorers(self.class_var) if self.class_var else []
self.score_table.update_header(scorers)
errors = []
self.scorer_errors = errors = []
for pred in self.predictors:
results = pred.results
if not isinstance(results, Results) or results.predicted is None:
Expand All @@ -294,7 +333,7 @@ def _update_scores(self):
for scorer in scorers:
item = QStandardItem()
try:
score = scorer_caller(scorer, results)()[0]
score = scorer_caller(scorer, results, target=target)()[0]
item.setText(f"{score:.3f}")
except Exception as exc: # pylint: disable=broad-except
item.setToolTip(str(exc))
Expand All @@ -304,17 +343,26 @@ def _update_scores(self):
row.append(item)
self.score_table.model.appendRow(row)

self._update_score_table_visibility()

def _update_score_table_visibility(self):
view = self.score_table.view
if model.rowCount():
nmodels = self.score_table.model.rowCount()
if nmodels and self.show_scores:
view.setVisible(True)
view.ensurePolished()
view.resizeColumnsToContents()
view.resizeRowsToContents()
view.setFixedHeight(
5 + view.horizontalHeader().height() +
view.verticalHeader().sectionSize(0) * model.rowCount())
view.verticalHeader().sectionSize(0) * nmodels)

errors = "\n".join(self.scorer_errors)
self.Error.scorer_failed(errors, shown=bool(errors))
else:
view.setVisible(False)
self.Error.scorer_failed.clear()

self.Error.scorer_failed("\n".join(errors), shown=bool(errors))

def _set_errors(self):
# Not all predictors are run every time, so errors can't be collected
Expand Down

0 comments on commit 136dafc

Please sign in to comment.