Skip to content

Commit

Permalink
Merge pull request #3172 from matejklemen/enh_roc_thresholds
Browse files Browse the repository at this point in the history
[ENH] ROC analysis: show thresholds
  • Loading branch information
ales-erjavec authored Sep 12, 2018
2 parents a95245e + 099ee74 commit eb9c1e0
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 17 deletions.
69 changes: 67 additions & 2 deletions Orange/widgets/evaluate/owrocanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import numpy
import sklearn.metrics as skl_metrics

from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont
from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, QToolTip
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, QCursor
from AnyQt.QtCore import Qt
import pyqtgraph as pg

Expand Down Expand Up @@ -336,6 +336,7 @@ def __init__(self):
self._plot_curves = {}
self._rocch = None
self._perf_line = None
self._tooltip_cache = None

box = gui.vBox(self.controlArea, "Plot")
tbox = gui.vBox(box, "Target Class")
Expand Down Expand Up @@ -395,6 +396,7 @@ def __init__(self):

self.plotview = pg.GraphicsView(background="w")
self.plotview.setFrameStyle(QFrame.StyledPanel)
self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved)

self.plot = pg.PlotItem(enableMenu=False)
self.plot.setMouseEnabled(False, False)
Expand Down Expand Up @@ -445,6 +447,7 @@ def clear(self):
self._plot_curves = {}
self._rocch = None
self._perf_line = None
self._tooltip_cache = None

def _initialize(self, results):
names = getattr(results, "learner_names", None)
Expand Down Expand Up @@ -601,6 +604,68 @@ def _setup_plot(self):
warning = "All ROC curves are undefined"
self.warning(warning)

def _on_mouse_moved(self, pos):
target = self.target_index
selected = self.selected_classifiers
curves = [(clf_idx, self.plot_curves(target, clf_idx))
for clf_idx in selected] # type: List[Tuple[int, plot_curves]]
valid_thresh, valid_clf = [], []
pt, ave_mode = None, self.roc_averaging

for clf_idx, crv in curves:
if self.roc_averaging == OWROCAnalysis.Merge:
curve = crv.merge()
elif self.roc_averaging == OWROCAnalysis.Vertical:
curve = crv.avg_vertical()
elif self.roc_averaging == OWROCAnalysis.Threshold:
curve = crv.avg_threshold()
else:
# currently not implemented for 'Show Individual Curves'
return

sp = curve.curve_item.childItems()[0] # type: pg.ScatterPlotItem
act_pos = sp.mapFromScene(pos)
pts = sp.pointsAt(act_pos)

if len(pts) > 0:
mouse_pt = pts[0].pos()
if self._tooltip_cache:
cache_pt, cache_thresh, cache_clf, cache_ave = self._tooltip_cache
curr_thresh, curr_clf = [], []
if numpy.linalg.norm(mouse_pt - cache_pt) < 10e-6 \
and cache_ave == self.roc_averaging:
mask = numpy.equal(cache_clf, clf_idx)
curr_thresh = numpy.compress(mask, cache_thresh).tolist()
curr_clf = numpy.compress(mask, cache_clf).tolist()
else:
QToolTip.showText(QCursor.pos(), "")
self._tooltip_cache = None

if curr_thresh:
valid_thresh.append(*curr_thresh)
valid_clf.append(*curr_clf)
pt = cache_pt
continue

curve_pts = curve.curve.points
roc_points = numpy.column_stack((curve_pts.fpr, curve_pts.tpr))
diff = numpy.subtract(roc_points, mouse_pt)
# Find closest point on curve and save the corresponding threshold
idx_closest = numpy.argmin(numpy.linalg.norm(diff, axis=1))

thresh = curve_pts.thresholds[idx_closest]
if not numpy.isnan(thresh):
valid_thresh.append(thresh)
valid_clf.append(clf_idx)
pt = [curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]]

if valid_thresh:
clf_names = self.classifier_names
msg = "Thresholds:\n" + "\n".join(["({:s}) {:.3f}".format(clf_names[i], thresh)
for i, thresh in zip(valid_clf, valid_thresh)])
QToolTip.showText(QCursor.pos(), msg)
self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode)

def _on_target_changed(self):
self.plot.clear()
self._setup_plot()
Expand Down
60 changes: 60 additions & 0 deletions Orange/widgets/evaluate/tests/test_owrocanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import copy
import numpy as np

from AnyQt.QtWidgets import QToolTip

import Orange.data
import Orange.evaluation
import Orange.classification
Expand All @@ -12,6 +14,7 @@
from Orange.widgets.evaluate.owrocanalysis import OWROCAnalysis
from Orange.widgets.evaluate.tests.base import EvaluateTest
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import mouseMove


class TestROC(unittest.TestCase):
Expand Down Expand Up @@ -156,3 +159,60 @@ def test_nan_input(self):
self.assertTrue(self.widget.Error.invalid_results.is_shown())
self.send_signal(self.widget.Inputs.evaluation_results, None)
self.assertFalse(self.widget.Error.invalid_results.is_shown())

def test_tooltips(self):
data_in = Orange.data.Table("titanic")
res = Orange.evaluation.TestOnTrainingData(
data=data_in,
learners=[Orange.classification.KNNLearner(),
Orange.classification.LogisticRegressionLearner()],
store_data=True
)

self.send_signal(self.widget.Inputs.evaluation_results, res)
self.widget.roc_averaging = OWROCAnalysis.Merge
self.widget.target_index = 0
self.widget.selected_classifiers = [0, 1]
vb = self.widget.plot.getViewBox()
vb.childTransform() # Force pyqtgraph to update transforms

curve = self.widget.plot_curves(self.widget.target_index, 0)
curve_merge = curve.merge()
view = self.widget.plotview
item = curve_merge.curve_item # type: pg.PlotCurveItem

# no tooltips to be shown
pos = item.mapToScene(0.0, 1.0)
pos = view.mapFromScene(pos)
mouseMove(view.viewport(), pos)
self.assertIs(self.widget._tooltip_cache, None)

# test single point
pos = item.mapToScene(0.22504, 0.45400)
pos = view.mapFromScene(pos)
mouseMove(view.viewport(), pos)
shown_thresh = self.widget._tooltip_cache[1]
self.assertTrue(QToolTip.isVisible())
np.testing.assert_almost_equal(shown_thresh, [0.40000], decimal=5)

pos = item.mapToScene(0.0, 0.0)
pos = view.mapFromScene(pos)
# test overlapping points
mouseMove(view.viewport(), pos)
shown_thresh = self.widget._tooltip_cache[1]
self.assertTrue(QToolTip.isVisible())
np.testing.assert_almost_equal(shown_thresh, [1.8, 1.89336], decimal=5)

# test that cache is invalidated when changing averaging mode
self.widget.roc_averaging = OWROCAnalysis.Threshold
self.widget._replot()
mouseMove(view.viewport(), pos)
shown_thresh = self.widget._tooltip_cache[1]
self.assertTrue(QToolTip.isVisible())
np.testing.assert_almost_equal(shown_thresh, [1, 1])

# test nan thresholds
self.widget.roc_averaging = OWROCAnalysis.Vertical
self.widget._replot()
mouseMove(view.viewport(), pos)
self.assertIs(self.widget._tooltip_cache, None)
16 changes: 15 additions & 1 deletion Orange/widgets/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import warnings
import contextlib

from AnyQt.QtCore import Qt, QObject, QEventLoop, QTimer, QLocale
from AnyQt.QtCore import Qt, QObject, QEventLoop, QTimer, QLocale, QPoint
from AnyQt.QtTest import QTest
from AnyQt.QtGui import QMouseEvent
from AnyQt.QtWidgets import QApplication


class EventSpy(QObject):
Expand Down Expand Up @@ -303,3 +305,15 @@ def wrap(*args, **kwargs):
return result
return wrap
return wrapper


def mouseMove(widget, pos=QPoint(), delay=-1): # pragma: no-cover
# Like QTest.mouseMove, but functional without QCursor.setPos
if pos.isNull():
pos = widget.rect().center()
me = QMouseEvent(QMouseEvent.MouseMove, pos, widget.mapToGlobal(pos),
Qt.NoButton, Qt.MouseButtons(0), Qt.NoModifier)
if delay > 0:
QTest.qWait(delay)

QApplication.sendEvent(widget, me)
16 changes: 2 additions & 14 deletions Orange/widgets/utils/tests/test_combobox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import unittest

from AnyQt.QtCore import Qt, QPoint, QRect
from AnyQt.QtGui import QMouseEvent
from AnyQt.QtCore import Qt, QRect
from AnyQt.QtWidgets import QListView, QApplication
from AnyQt.QtTest import QTest, QSignalSpy
from Orange.widgets.tests.base import GuiTest
from Orange.widgets.tests.utils import mouseMove

from Orange.widgets.utils import combobox

Expand Down Expand Up @@ -133,15 +133,3 @@ def test_popup_util(self):
geom, QRect(0, 500, 100, 20), screen
)
self.assertEqual(g4, QRect(0, 500 - 400, 100, 400))


def mouseMove(widget, pos=QPoint(), delay=-1): # pragma: no-cover
# Like QTest.mouseMove, but functional without QCursor.setPos
if pos.isNull():
pos = widget.rect().center()
me = QMouseEvent(QMouseEvent.MouseMove, pos, widget.mapToGlobal(pos),
Qt.NoButton, Qt.MouseButtons(0), Qt.NoModifier)
if delay > 0:
QTest.qWait(delay)

QApplication.sendEvent(widget, me)

0 comments on commit eb9c1e0

Please sign in to comment.