Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] ROC analysis: show thresholds #3172

Merged
merged 8 commits into from
Sep 12, 2018
69 changes: 67 additions & 2 deletions Orange/widgets/evaluate/owrocanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import numpy
import sklearn.metrics as skl_metrics

from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont
from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, QToolTip
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, QCursor
from AnyQt.QtCore import Qt
import pyqtgraph as pg

Expand Down Expand Up @@ -336,6 +336,7 @@ def __init__(self):
self._plot_curves = {}
self._rocch = None
self._perf_line = None
self._tooltip_cache = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand why the tooltip cache is necessary. Can you add a comment explaining the reasoning for its use?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning behind it was that instead of checking all points of a curve to find the closest one, it would first check the tooltip cache, which contains at most 1 threshold (point) per (selected) classifier.

It's hard to say if it's worth the additional 20-ish lines of code though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it already checks all points in pts = sp.pointsAt(act_pos)

Copy link
Contributor Author

@matejklemen matejklemen Aug 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but it also does it again later (although it is vectorized): idx_closest = numpy.argmin(numpy.linalg.norm(diff, axis=1)).

Considering your point, it does seem that the performance gain from tooltip cache might not be as big as I thought. Should I remove it and make the code clearer that way?

Edit: also, the tooltip cache provides a convenient way to test the contents of tooltips - the first alternative way I can think of would be to test the entire tooltip text, e.g. assertEqual("Tooltips:\n(#1) 0.400", QTooltip.text()).


box = gui.vBox(self.controlArea, "Plot")
tbox = gui.vBox(box, "Target Class")
Expand Down Expand Up @@ -395,6 +396,7 @@ def __init__(self):

self.plotview = pg.GraphicsView(background="w")
self.plotview.setFrameStyle(QFrame.StyledPanel)
self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved)

self.plot = pg.PlotItem(enableMenu=False)
self.plot.setMouseEnabled(False, False)
Expand Down Expand Up @@ -445,6 +447,7 @@ def clear(self):
self._plot_curves = {}
self._rocch = None
self._perf_line = None
self._tooltip_cache = None

def _initialize(self, results):
names = getattr(results, "learner_names", None)
Expand Down Expand Up @@ -601,6 +604,68 @@ def _setup_plot(self):
warning = "All ROC curves are undefined"
self.warning(warning)

def _on_mouse_moved(self, pos):
target = self.target_index
selected = self.selected_classifiers
curves = [(clf_idx, self.plot_curves(target, clf_idx))
for clf_idx in selected] # type: List[Tuple[int, plot_curves]]
valid_thresh, valid_clf = [], []
pt, ave_mode = None, self.roc_averaging

for clf_idx, crv in curves:
if self.roc_averaging == OWROCAnalysis.Merge:
curve = crv.merge()
elif self.roc_averaging == OWROCAnalysis.Vertical:
curve = crv.avg_vertical()
elif self.roc_averaging == OWROCAnalysis.Threshold:
curve = crv.avg_threshold()
else:
# currently not implemented for 'Show Individual Curves'
return

sp = curve.curve_item.childItems()[0] # type: pg.ScatterPlotItem
act_pos = sp.mapFromScene(pos)
pts = sp.pointsAt(act_pos)

if len(pts) > 0:
mouse_pt = pts[0].pos()
if self._tooltip_cache:
cache_pt, cache_thresh, cache_clf, cache_ave = self._tooltip_cache
curr_thresh, curr_clf = [], []
if numpy.linalg.norm(mouse_pt - cache_pt) < 10e-6 \
and cache_ave == self.roc_averaging:
mask = numpy.equal(cache_clf, clf_idx)
curr_thresh = numpy.compress(mask, cache_thresh).tolist()
curr_clf = numpy.compress(mask, cache_clf).tolist()
else:
QToolTip.showText(QCursor.pos(), "")
self._tooltip_cache = None

if curr_thresh:
valid_thresh.append(*curr_thresh)
valid_clf.append(*curr_clf)
pt = cache_pt
continue

curve_pts = curve.curve.points
roc_points = numpy.column_stack((curve_pts.fpr, curve_pts.tpr))
diff = numpy.subtract(roc_points, mouse_pt)
# Find closest point on curve and save the corresponding threshold
idx_closest = numpy.argmin(numpy.linalg.norm(diff, axis=1))

thresh = curve_pts.thresholds[idx_closest]
if not numpy.isnan(thresh):
valid_thresh.append(thresh)
valid_clf.append(clf_idx)
pt = [curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]]

if valid_thresh:
clf_names = self.classifier_names
msg = "Thresholds:\n" + "\n".join(["({:s}) {:.3f}".format(clf_names[i], thresh)
for i, thresh in zip(valid_clf, valid_thresh)])
QToolTip.showText(QCursor.pos(), msg)
self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode)

def _on_target_changed(self):
self.plot.clear()
self._setup_plot()
Expand Down
60 changes: 60 additions & 0 deletions Orange/widgets/evaluate/tests/test_owrocanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import copy
import numpy as np

from AnyQt.QtWidgets import QToolTip

import Orange.data
import Orange.evaluation
import Orange.classification
Expand All @@ -12,6 +14,7 @@
from Orange.widgets.evaluate.owrocanalysis import OWROCAnalysis
from Orange.widgets.evaluate.tests.base import EvaluateTest
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import mouseMove


class TestROC(unittest.TestCase):
Expand Down Expand Up @@ -156,3 +159,60 @@ def test_nan_input(self):
self.assertTrue(self.widget.Error.invalid_results.is_shown())
self.send_signal(self.widget.Inputs.evaluation_results, None)
self.assertFalse(self.widget.Error.invalid_results.is_shown())

def test_tooltips(self):
data_in = Orange.data.Table("titanic")
res = Orange.evaluation.TestOnTrainingData(
data=data_in,
learners=[Orange.classification.KNNLearner(),
Orange.classification.LogisticRegressionLearner()],
store_data=True
)

self.send_signal(self.widget.Inputs.evaluation_results, res)
self.widget.roc_averaging = OWROCAnalysis.Merge
self.widget.target_index = 0
self.widget.selected_classifiers = [0, 1]
vb = self.widget.plot.getViewBox()
vb.childTransform() # Force pyqtgraph to update transforms

curve = self.widget.plot_curves(self.widget.target_index, 0)
curve_merge = curve.merge()
view = self.widget.plotview
item = curve_merge.curve_item # type: pg.PlotCurveItem

# no tooltips to be shown
pos = item.mapToScene(0.0, 1.0)
pos = view.mapFromScene(pos)
mouseMove(view.viewport(), pos)
self.assertIs(self.widget._tooltip_cache, None)

# test single point
pos = item.mapToScene(0.22504, 0.45400)
pos = view.mapFromScene(pos)
mouseMove(view.viewport(), pos)
shown_thresh = self.widget._tooltip_cache[1]
self.assertTrue(QToolTip.isVisible())
np.testing.assert_almost_equal(shown_thresh, [0.40000], decimal=5)

pos = item.mapToScene(0.0, 0.0)
pos = view.mapFromScene(pos)
# test overlapping points
mouseMove(view.viewport(), pos)
shown_thresh = self.widget._tooltip_cache[1]
self.assertTrue(QToolTip.isVisible())
np.testing.assert_almost_equal(shown_thresh, [1.8, 1.89336], decimal=5)

# test that cache is invalidated when changing averaging mode
self.widget.roc_averaging = OWROCAnalysis.Threshold
self.widget._replot()
mouseMove(view.viewport(), pos)
shown_thresh = self.widget._tooltip_cache[1]
self.assertTrue(QToolTip.isVisible())
np.testing.assert_almost_equal(shown_thresh, [1, 1])

# test nan thresholds
self.widget.roc_averaging = OWROCAnalysis.Vertical
self.widget._replot()
mouseMove(view.viewport(), pos)
self.assertIs(self.widget._tooltip_cache, None)
16 changes: 15 additions & 1 deletion Orange/widgets/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import warnings
import contextlib

from AnyQt.QtCore import Qt, QObject, QEventLoop, QTimer, QLocale
from AnyQt.QtCore import Qt, QObject, QEventLoop, QTimer, QLocale, QPoint
from AnyQt.QtTest import QTest
from AnyQt.QtGui import QMouseEvent
from AnyQt.QtWidgets import QApplication


class EventSpy(QObject):
Expand Down Expand Up @@ -303,3 +305,15 @@ def wrap(*args, **kwargs):
return result
return wrap
return wrapper


def mouseMove(widget, pos=QPoint(), delay=-1): # pragma: no-cover
# Like QTest.mouseMove, but functional without QCursor.setPos
if pos.isNull():
pos = widget.rect().center()
me = QMouseEvent(QMouseEvent.MouseMove, pos, widget.mapToGlobal(pos),
Qt.NoButton, Qt.MouseButtons(0), Qt.NoModifier)
if delay > 0:
QTest.qWait(delay)

QApplication.sendEvent(widget, me)
16 changes: 2 additions & 14 deletions Orange/widgets/utils/tests/test_combobox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import unittest

from AnyQt.QtCore import Qt, QPoint, QRect
from AnyQt.QtGui import QMouseEvent
from AnyQt.QtCore import Qt, QRect
from AnyQt.QtWidgets import QListView, QApplication
from AnyQt.QtTest import QTest, QSignalSpy
from Orange.widgets.tests.base import GuiTest
from Orange.widgets.tests.utils import mouseMove

from Orange.widgets.utils import combobox

Expand Down Expand Up @@ -133,15 +133,3 @@ def test_popup_util(self):
geom, QRect(0, 500, 100, 20), screen
)
self.assertEqual(g4, QRect(0, 500 - 400, 100, 400))


def mouseMove(widget, pos=QPoint(), delay=-1): # pragma: no-cover
# Like QTest.mouseMove, but functional without QCursor.setPos
if pos.isNull():
pos = widget.rect().center()
me = QMouseEvent(QMouseEvent.MouseMove, pos, widget.mapToGlobal(pos),
Qt.NoButton, Qt.MouseButtons(0), Qt.NoModifier)
if delay > 0:
QTest.qWait(delay)

QApplication.sendEvent(widget, me)