Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ENH] Test & Score: Add comparison of models #4261

Merged
merged 5 commits into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 209 additions & 6 deletions Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint doesn't understand the Settings magic
# pylint: disable=invalid-sequence-index
# pylint: disable=too-many-lines,too-many-instance-attributes
import abc
import enum
import logging
Expand All @@ -9,14 +10,17 @@

from concurrent.futures import Future
from collections import OrderedDict, namedtuple
from itertools import count
from typing import Any, Optional, List, Dict, Callable

import numpy as np
import baycomp

from AnyQt import QtGui
from AnyQt.QtGui import QStandardItem
from AnyQt.QtCore import Qt, QSize, QThread
from AnyQt.QtCore import pyqtSlot as Slot
from AnyQt.QtGui import QStandardItem, QDoubleValidator
from AnyQt.QtWidgets import QHeaderView, QTableWidget, QLabel

from Orange.base import Learner
import Orange.classification
Expand All @@ -35,7 +39,7 @@
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.utils.concurrent import ThreadExecutor, TaskState
from Orange.widgets.widget import OWWidget, Msg, Input, Output

from orangewidget.utils.itemmodels import PyListModel

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -175,6 +179,10 @@ class Outputs:
fold_feature = settings.ContextSetting(None)
fold_feature_selected = settings.ContextSetting(False)

use_rope = settings.Setting(False)
rope = settings.Setting(0.1)
comparison_criterion = settings.Setting(0, schema_only=True)

TARGET_AVERAGE = "(Average over classes)"
class_selection = settings.ContextSetting(TARGET_AVERAGE)

Expand Down Expand Up @@ -216,6 +224,7 @@ def __init__(self):
self.train_data_missing_vals = False
self.test_data_missing_vals = False
self.scorers = []
self.__pending_comparison_criterion = self.comparison_criterion

#: An Ordered dictionary with current inputs and their testing results.
self.learners = OrderedDict() # type: Dict[Any, Input]
Expand Down Expand Up @@ -275,13 +284,55 @@ def __init__(self):
callback=self._on_target_class_changed,
contentsLength=8)

self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
gui.comboBox(
box, self, "comparison_criterion", model=PyListModel(),
callback=self.update_comparison_table)

hbox = gui.hBox(box)
gui.checkBox(hbox, self, "use_rope",
"Negligible difference: ",
callback=self._on_use_rope_changed)
gui.lineEdit(hbox, self, "rope", validator=QDoubleValidator(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not a spinbox?
It should probably disabled when use_rope is not checked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a spin box because it has no defined range. It can refer to AUC, that is, between 0 and 1, or it can be RMSE, which is between 0 and infinity -- it can easily be 100000.

Disabling it would make sense, I'll do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added disabling but didn't like it. Let's let the user change the line edit first and then enable it, if (s)he wishes.

I added a method _on_use_rope_changed. You can add a line self.controls.rope.setEnabled(self.use_rope) and see for yourself that you won't like it. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a won't fix. :)

controlWidth=70, callback=self.update_comparison_table,
alignment=Qt.AlignRight)
self.controls.rope.setEnabled(self.use_rope)

gui.rubber(self.controlArea)
self.score_table = ScoreTable(self)
self.score_table.shownScoresChanged.connect(self.update_stats_model)
view = self.score_table.view
view.setSizeAdjustPolicy(view.AdjustToContents)

box = gui.vBox(self.mainArea, "Evaluation Results")
box.layout().addWidget(self.score_table.view)

self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
table = self.comparison_table = QTableWidget(
wordWrap=False, editTriggers=QTableWidget.NoEditTriggers,
selectionMode=QTableWidget.NoSelection)
table.setSizeAdjustPolicy(table.AdjustToContents)
header = table.verticalHeader()
header.setSectionResizeMode(QHeaderView.Fixed)
header.setSectionsClickable(False)

header = table.horizontalHeader()
header.setTextElideMode(Qt.ElideRight)
header.setDefaultAlignment(Qt.AlignCenter)
header.setSectionsClickable(False)
header.setStretchLastSection(False)
header.setSectionResizeMode(QHeaderView.ResizeToContents)
avg_width = self.fontMetrics().averageCharWidth()
header.setMinimumSectionSize(8 * avg_width)
header.setMaximumSectionSize(15 * avg_width)
header.setDefaultSectionSize(15 * avg_width)
box.layout().addWidget(table)
box.layout().addWidget(QLabel(
"<small>Table shows probabilities that the score for the model in "
"the row is higher than that of the model in the column. "
"Small numbers show the probability that the difference is "
"negligible.</small>", wordWrap=True))

@staticmethod
def sizeHint():
    # Suggest a wide but minimal-height window; the layout grows the
    # height to fit the score and comparison tables.
    hint = QSize(780, 1)
    return hint
Expand Down Expand Up @@ -436,10 +487,32 @@ def _which_missing_data(self):
# - we don't gain much with it
# - it complicates the unit tests
def _update_scorers(self):
    """Refresh `self.scorers` for the current data's class variable.

    Rebuilds the comparison-criterion combo's model when the set of
    usable scorers changes, and applies any criterion index restored
    from saved (schema-only) settings.
    """
    if self.data and self.data.domain.class_var:
        new_scorers = usable_scorers(self.data.domain.class_var)
    else:
        new_scorers = []
    # Don't unnecessarily reset the model because this would always reset
    # comparison_criterion; we also set it explicitly, though, for clarity
    if new_scorers != self.scorers:
        self.scorers = new_scorers
        self.controls.comparison_criterion.model()[:] = \
            [scorer.long_name or scorer.name for scorer in self.scorers]
        self.comparison_criterion = 0
    if self.__pending_comparison_criterion is not None:
        # Check for the unlikely case that some scorers have been removed
        # from modules
        if self.__pending_comparison_criterion < len(self.scorers):
            self.comparison_criterion = self.__pending_comparison_criterion
        self.__pending_comparison_criterion = None
    self._update_compbox_title()

def _update_compbox_title(self):
    """Show the selected scorer's name in the comparison box title.

    Falls back to a generic title when the stored criterion index is out
    of range (e.g. no data, hence no scorers).
    """
    criterion = self.comparison_criterion
    if criterion < len(self.scorers):
        scorer = self.scorers[criterion]()
        self.compbox.setTitle(f"Model Comparison by {scorer.name}")
    else:
        # Plain string: the original used an f-string with no placeholders
        self.compbox.setTitle("Model Comparison")

@Inputs.preprocessor
def set_preprocessor(self, preproc):
Expand All @@ -453,6 +526,7 @@ def handleNewSignals(self):
"""Reimplemented from OWWidget.handleNewSignals."""
self._update_class_selection()
self.score_table.update_header(self.scorers)
self._update_view_enabled()
self.update_stats_model()
if self.__needupdate:
self.__update()
Expand All @@ -470,9 +544,19 @@ def shuffle_split_changed(self):
self._param_changed()

def _param_changed(self):
    # Model comparison is only meaningful for k-fold cross validation,
    # so the comparison controls follow the resampling choice.
    is_kfold = self.resampling == OWTestLearners.KFold
    self.modcompbox.setEnabled(is_kfold)
    self._update_view_enabled()
    self._invalidate()
    self.__update()

def _update_view_enabled(self):
    """Enable/disable the result views to match the widget's state.

    The comparison table requires k-fold CV, at least two learners and
    data; disabling (rather than hiding) it signals to the user that the
    blank table is intentional. The score table only requires data.
    """
    self.comparison_table.setEnabled(
        self.resampling == OWTestLearners.KFold
        and len(self.learners) > 1
        and self.data is not None)
    self.score_table.view.setEnabled(
        self.data is not None)

def update_stats_model(self):
# Update the results_model with up to date scores.
# Note: The target class specific scores (if requested) are
Expand All @@ -494,8 +578,10 @@ def update_stats_model(self):
errors = []
has_missing_scores = False

names = []
for key, slot in self.learners.items():
name = learner_name(slot.learner)
names.append(name)
head = QStandardItem(name)
head.setData(key, Qt.UserRole)
results = slot.results
Expand Down Expand Up @@ -558,10 +644,123 @@ def update_stats_model(self):
header.sortIndicatorSection(),
header.sortIndicatorOrder()
)
self._set_comparison_headers(names)

self.error("\n".join(errors), shown=bool(errors))
self.Warning.scores_not_computed(shown=has_missing_scores)

def _on_use_rope_changed(self):
    # Keep the rope line edit enabled only while the checkbox is ticked,
    # then recompute the comparison, which honors `use_rope`.
    self.controls.rope.setEnabled(self.use_rope)
    self.update_comparison_table()

def update_comparison_table(self):
    """Clear and, when possible, recompute the pairwise comparison table."""
    self.comparison_table.clearContents()
    slots = self._successful_slots()
    if not slots or not self.scorers:
        return
    names = [learner_name(slot.learner) for slot in slots]
    self._set_comparison_headers(names)
    # Comparison probabilities are defined only for k-fold CV results
    if self.resampling == OWTestLearners.KFold:
        self._fill_table(names, self._scores_by_folds(slots))

def _successful_slots(self):
    """Return learner slots with successful results, in the view's sort order."""
    model = self.score_table.model
    proxy = self.score_table.sorted_model
    slots = []
    for row in range(proxy.rowCount()):
        key = model.data(proxy.mapToSource(proxy.index(row, 0)), Qt.UserRole)
        slot = self.learners[key]
        if slot.results is not None and slot.results.success:
            slots.append(slot)
    return slots

def _set_comparison_headers(self, names):
    """Resize the comparison table and label its rows and columns with `names`."""
    table = self.comparison_table
    try:
        # Prevent glitching during update
        table.setUpdatesEnabled(False)
        resize_mode = (QHeaderView.Stretch if len(names) > 2
                       else QHeaderView.Fixed)
        table.horizontalHeader().setSectionResizeMode(resize_mode)
        size = len(names)
        table.setRowCount(size)
        table.setColumnCount(size)
        table.setVerticalHeaderLabels(names)
        table.setHorizontalHeaderLabels(names)
    finally:
        table.setUpdatesEnabled(True)

def _scores_by_folds(self, slots):
    """Return a list of per-fold score arrays for the selected criterion.

    One entry per slot, in the same order; an entry is None when the
    scorer failed for that slot's results (a warning is shown in that
    case).
    """
    scorer = self.scorers[self.comparison_criterion]()
    self._update_compbox_title()
    if scorer.is_binary:
        # Binary scorers need either a specific target class or averaging
        if self.class_selection != self.TARGET_AVERAGE:
            class_var = self.data.domain.class_var
            target_index = class_var.values.index(self.class_selection)
            kw = dict(target=target_index)
        else:
            kw = dict(average='weighted')
    else:
        kw = {}

    def call_scorer(results):
        # Bind `results` now; `Try` calls the thunk later
        def thunked():
            return scorer.scores_by_folds(results.value, **kw).flatten()

        return thunked

    scores = [Try(call_scorer(slot.results)) for slot in slots]
    scores = [score.value if score.success else None for score in scores]
    # `None in scores` doesn't work -- the items are np.arrays
    if any(score is None for score in scores):
        self.Warning.scores_not_computed()
    return scores

def _fill_table(self, names, scores):
    """Fill the comparison table with Bayesian comparison probabilities.

    For each pair of models, cell (row, col) shows the probability that
    the row model's score exceeds the column model's; with a rope, the
    probability of a negligible difference is shown below it in small
    print. Pairs with missing scores are skipped; NaN results mark both
    cells as "NA".
    """
    table = self.comparison_table
    # `enumerate` instead of zip(count(), ...); only the lower-triangle
    # pairs are computed -- each call fills both symmetric cells
    for row, (row_name, row_scores) in enumerate(zip(names, scores)):
        for col in range(row):
            col_name, col_scores = names[col], scores[col]
            if row_scores is None or col_scores is None:
                continue
            if self.use_rope and self.rope:
                p0, rope, p1 = baycomp.two_on_single(
                    row_scores, col_scores, self.rope)
                if np.isnan(p0) or np.isnan(rope) or np.isnan(p1):
                    self._set_cells_na(table, row, col)
                    continue
                self._set_cell(table, row, col,
                               f"{p0:.3f}<br/><small>{rope:.3f}</small>",
                               f"p({row_name} > {col_name}) = {p0:.3f}\n"
                               f"p({row_name} = {col_name}) = {rope:.3f}")
                self._set_cell(table, col, row,
                               f"{p1:.3f}<br/><small>{rope:.3f}</small>",
                               f"p({col_name} > {row_name}) = {p1:.3f}\n"
                               f"p({col_name} = {row_name}) = {rope:.3f}")
            else:
                p0, p1 = baycomp.two_on_single(row_scores, col_scores)
                if np.isnan(p0) or np.isnan(p1):
                    self._set_cells_na(table, row, col)
                    continue
                self._set_cell(table, row, col,
                               f"{p0:.3f}",
                               f"p({row_name} > {col_name}) = {p0:.3f}")
                self._set_cell(table, col, row,
                               f"{p1:.3f}",
                               f"p({col_name} > {row_name}) = {p1:.3f}")

@classmethod
def _set_cells_na(cls, table, row, col):
    """Mark the symmetric pair of cells as not computable."""
    na_tooltip = "comparison cannot be computed"
    for r, c in ((row, col), (col, row)):
        cls._set_cell(table, r, c, "NA", na_tooltip)

@staticmethod
def _set_cell(table, row, col, label, tooltip):
    """Put a centered rich-text label with a tooltip into the given cell."""
    cell = QLabel(label)
    cell.setToolTip(tooltip)
    cell.setAlignment(Qt.AlignCenter)
    table.setCellWidget(row, col, cell)

def _update_class_selection(self):
self.class_selection_combo.setCurrentIndex(-1)
self.class_selection_combo.clear()
Expand All @@ -585,6 +784,7 @@ def _update_class_selection(self):

def _on_target_class_changed(self):
    # The target class affects per-class scores, so both the score table
    # and the model comparison must be recomputed.
    self.update_stats_model()
    self.update_comparison_table()

def _invalidate(self, which=None):
self.cancel()
Expand All @@ -611,6 +811,8 @@ def _invalidate(self, which=None):
item.setData(None, Qt.DisplayRole)
item.setData(None, Qt.ToolTipRole)

self.comparison_table.clearContents()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only clears the contents, but retains the headers.
I'm not sure where the right place to remove headers is, but it should be handled somewhere (once you remove the learners from the widget, their names are still present in the comparison table).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.


self.__needupdate = True

def commit(self):
Expand Down Expand Up @@ -866,6 +1068,7 @@ def __task_complete(self, f: 'Future[Results]'):

self.score_table.update_header(self.scorers)
self.update_stats_model()
self.update_comparison_table()

self.commit()

Expand Down
Loading