Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OWCorrelations: Use heuristic to get the most promising attribute pairs #70

Merged
merged 10 commits into from
Oct 5, 2018
153 changes: 126 additions & 27 deletions orangecontrib/prototypes/widgets/owcorrelations.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,156 @@
"""
Correlations widget
"""
from enum import IntEnum
from operator import attrgetter
from itertools import combinations, groupby, chain

import numpy as np
from Orange.widgets.utils.signals import Input, Output
from scipy.stats import spearmanr
from scipy.stats import spearmanr, pearsonr
from sklearn.cluster import KMeans

from AnyQt.QtCore import Qt, QItemSelectionModel, QItemSelection, QSize
from AnyQt.QtGui import QStandardItem
from AnyQt.QtWidgets import QHeaderView
from AnyQt.QtGui import QStandardItem, QColor

from Orange.data import Table, Domain, ContinuousVariable, StringVariable
from Orange.preprocess import SklImpute
from Orange.preprocess import SklImpute, Normalize
from Orange.widgets import gui
from Orange.widgets.settings import Setting, ContextSetting, \
DomainContextHandler
from Orange.widgets.utils.signals import Input, Output
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
from Orange.widgets.widget import OWWidget, AttributeList, Msg

NAN = 2
SIZE_LIMIT = 1000000


class CorrelationType(IntEnum):
"""
Correlation type enumerator. Possible correlations: Pearson, Spearman.
"""
PEARSON, SPEARMAN = 0, 1

@staticmethod
def items():
return ["Pairwise Pearson correlation", "Pairwise Spearman correlation"]
"""
Texts for correlation types. Can be used in gui controls (eg. combobox).
"""
return ["Pearson correlation", "Spearman correlation"]


class KMeansCorrelationHeuristic:
"""
Heuristic to obtain the most promising attribute pairs, when there are to
many attributes to calculate correlations for all possible pairs.
"""
n_clusters = 10

def __init__(self, data):
self.n_attributes = len(data.domain.attributes)
self.data = data
self.states = None

def get_clusters_of_attributes(self):
"""
Generates groupes of attribute IDs, grouped by cluster. Clusters are
obtained by KMeans algorithm.

:return: generator of attributes grouped by cluster
"""
data = Normalize()(self.data).X.T
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(data)
labels_attrs = sorted([(l, i) for i, l in enumerate(kmeans.labels_)])
for _, group in groupby(labels_attrs, key=lambda x: x[0]):
group = list(group)
if len(group) > 1:
yield list(pair[1] for pair in group)

def get_states(self, initial_state):
"""
Generates the most promising states (attribute pairs).

:param initial_state: initial state; None if this is the first call
:return: generator of tuples of states
"""
if self.states is not None:
return chain([initial_state], self.states)
self.states = chain.from_iterable(combinations(inds, 2) for inds in
self.get_clusters_of_attributes())
return self.states


class CorrelationRank(VizRankDialogAttrPair):
"""
Correlations rank widget.
"""
NEGATIVE_COLOR = QColor(70, 190, 250)
POSITIVE_COLOR = QColor(170, 242, 43)

def __init__(self, *args):
super().__init__(*args)
self.heuristic = None
self.use_heuristic = False

def initialize(self):
super().initialize()
data = self.master.cont_data
self.attrs = data and data.domain.attributes
self.model_proxy.setFilterKeyColumn(-1)
self.rank_table.horizontalHeader().setStretchLastSection(False)
self.heuristic = None
self.use_heuristic = False
if data:
# use heuristic if data is too big
n_attrs = len(self.attrs)
use_heuristic = n_attrs > KMeansCorrelationHeuristic.n_clusters
self.use_heuristic = use_heuristic and \
len(data) * n_attrs ** 2 > SIZE_LIMIT
if self.use_heuristic:
self.heuristic = KMeansCorrelationHeuristic(data)

def compute_score(self, state):
(a1, a2), corr_type = state, self.master.correlation_type
if corr_type == CorrelationType.PEARSON:
return -np.corrcoef(self.master.cont_data.X[:, [a1, a2]].T)[0, 1]
else:
return -spearmanr(self.master.cont_data.X[:, [a1, a2]])[0]
(attr1, attr2), corr_type = state, self.master.correlation_type
data = self.master.cont_data.X
corr = pearsonr if corr_type == CorrelationType.PEARSON else spearmanr
result = corr(data[:, attr1], data[:, attr2])[0]
return -abs(result) if not np.isnan(result) else NAN, result

def row_for_state(self, score, state):
attrs = sorted((self.attrs[x] for x in state), key=attrgetter("name"))
attr_1_item = QStandardItem(attrs[0].name)
attr_2_item = QStandardItem(attrs[1].name)
correlation_item = QStandardItem(str(round(-score, 3)))
attr_1_item.setData(attrs, self._AttrRole)
attr_2_item.setData(attrs, self._AttrRole)
correlation_item.setData(attrs)
correlation_item.setData(Qt.AlignCenter, Qt.TextAlignmentRole)
return [attr_1_item, attr_2_item, correlation_item]
attrs_item = QStandardItem(
"{}, {}".format(attrs[0].name, attrs[1].name))
attrs_item.setData(attrs, self._AttrRole)
attrs_item.setData(Qt.AlignLeft + Qt.AlignTop, Qt.TextAlignmentRole)
correlation_item = QStandardItem("{:+.3f}".format(score[1]))
correlation_item.setData(attrs, self._AttrRole)
correlation_item.setData(
self.NEGATIVE_COLOR if score[1] < 0 else self.POSITIVE_COLOR,
gui.TableBarItem.BarColorRole)
return [correlation_item, attrs_item]

def check_preconditions(self):
return self.master.cont_data is not None

def iterate_states(self, initial_state):
if self.use_heuristic:
return self.heuristic.get_states(initial_state)
else:
return super().iterate_states(initial_state)

def state_count(self):
if self.use_heuristic:
n_clusters = KMeansCorrelationHeuristic.n_clusters
n_avg_attrs = len(self.attrs) / n_clusters
return n_clusters * n_avg_attrs * (n_avg_attrs - 1) / 2
else:
n_attrs = len(self.attrs)
return n_attrs * (n_attrs - 1) / 2

@staticmethod
def bar_length(score):
return abs(score[1])


class OWCorrelations(OWWidget):
name = "Correlations"
Expand Down Expand Up @@ -93,6 +189,7 @@ def __init__(self):

self.vizrank, _ = CorrelationRank.add_vizrank(
None, self, None, self._vizrank_selection_changed)
self.vizrank.progressBar = self.progressBar

gui.separator(box)
box.layout().addWidget(self.vizrank.filter)
Expand All @@ -114,10 +211,12 @@ def _vizrank_selection_changed(self, *args):
def _vizrank_select(self):
model = self.vizrank.rank_table.model()
selection = QItemSelection()
names = sorted(x.name for x in self.selection)
for i in range(model.rowCount()):
if model.data(model.index(i, 0)) == self.selection[0].name and \
model.data(model.index(i, 1)) == self.selection[1].name:
selection.select(model.index(i, 0), model.index(i, 2))
if sorted(x.name for x in model.data(model.index(i, 0),
CorrelationRank._AttrRole)) \
== names:
selection.select(model.index(i, 0), model.index(i, 1))
self.vizrank.rank_table.selectionModel().select(
selection, QItemSelectionModel.ClearAndSelect)
break
Expand Down Expand Up @@ -150,7 +249,6 @@ def apply(self):
self.vizrank.toggle()
header = self.vizrank.rank_table.horizontalHeader()
header.setStretchLastSection(True)
header.setSectionResizeMode(QHeaderView.ResizeToContents)
else:
self.commit()

Expand All @@ -164,11 +262,12 @@ def commit(self):
metas = [StringVariable("Feature 1"), StringVariable("Feature 2")]
domain = Domain([ContinuousVariable("Correlation")], metas=metas)
model = self.vizrank.rank_model
x = np.array([[float(model.data(model.index(row, 2)))] for row
x = np.array([[float(model.data(model.index(row, 0)))] for row
in range(model.rowCount())])
m = np.array([[model.data(model.index(row, 0)),
model.data(model.index(row, 1))] for row
in range(model.rowCount())], dtype=object)
m = np.array([[attr.name
for attr in model.data(model.index(row, 0),
CorrelationRank._AttrRole)]
for row in range(model.rowCount())], dtype=object)
corr_table = Table(domain, x, metas=m)
corr_table.name = "Correlations"

Expand Down
78 changes: 56 additions & 22 deletions orangecontrib/prototypes/widgets/tests/test_owcorrelations.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import time
from Orange.data import Table
from Orange.widgets.visualize.owscatterplot import OWScatterPlot
from orangecontrib.prototypes.widgets.owcorrelations import OWCorrelations
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.widget import AttributeList
from orangecontrib.prototypes.widgets.owcorrelations import OWCorrelations, \
KMeansCorrelationHeuristic


class TestOWCorrelations(WidgetTest):
Expand All @@ -20,75 +22,107 @@ def setUp(self):

def test_input_data_cont(self):
"""Check correlation table for dataset with continuous attributes"""
self.send_signal("Data", self.data_cont)
self.send_signal(self.widget.Inputs.data, self.data_cont)
time.sleep(0.1)
n_attrs = len(self.data_cont.domain.attributes)
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 3)
self.process_events()
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 2)
self.assertEqual(self.widget.vizrank.rank_model.rowCount(),
n_attrs * (n_attrs - 1) / 2)
self.send_signal("Data", None)
self.send_signal(self.widget.Inputs.data, None)
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 0)
self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 0)

def test_input_data_disc(self):
"""Check correlation table for dataset with discrete attributes"""
self.send_signal("Data", self.data_disc)
self.send_signal(self.widget.Inputs.data, self.data_disc)
self.assertTrue(self.widget.Information.not_enough_vars.is_shown())
self.send_signal("Data", None)
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Information.not_enough_vars.is_shown())

def test_input_data_mixed(self):
"""Check correlation table for dataset with continuous and discrete
attributes"""
self.send_signal("Data", self.data_mixed)
self.send_signal(self.widget.Inputs.data, self.data_mixed)
domain = self.data_mixed.domain
n_attrs = len([a for a in domain.attributes if a.is_continuous])
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 3)
time.sleep(0.1)
self.process_events()
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 2)
self.assertEqual(self.widget.vizrank.rank_model.rowCount(),
n_attrs * (n_attrs - 1) / 2)

def test_input_data_one_feature(self):
"""Check correlation table for dataset with one attribute"""
self.send_signal("Data", self.data_cont[:, [0, 4]])
self.send_signal(self.widget.Inputs.data, self.data_cont[:, [0, 4]])
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 0)
self.assertTrue(self.widget.Information.not_enough_vars.is_shown())
self.send_signal("Data", None)
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Information.not_enough_vars.is_shown())

def test_input_data_one_instance(self):
"""Check correlation table for dataset with one instance"""
self.send_signal("Data", self.data_cont[:1])
self.send_signal(self.widget.Inputs.data, self.data_cont[:1])
self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 0)
self.assertTrue(self.widget.Information.not_enough_inst.is_shown())
self.send_signal("Data", None)
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Information.not_enough_inst.is_shown())

def test_output_data(self):
"""Check dataset on output"""
self.send_signal("Data", self.data_cont)
self.assertEqual(self.data_cont, self.get_output("Data"))
self.send_signal(self.widget.Inputs.data, self.data_cont)
time.sleep(0.1)
self.process_events()
self.widget.commit()
self.assertEqual(self.data_cont, self.get_output(self.widget.Outputs.data))

def test_output_features(self):
"""Check features on output"""
self.send_signal("Data", self.data_cont)
features = self.get_output("Features")
self.send_signal(self.widget.Inputs.data, self.data_cont)
time.sleep(0.1)
self.process_events()
attrs = self.widget.cont_data.domain.attributes
self.widget._vizrank_selection_changed(attrs[0], attrs[1])
features = self.get_output(self.widget.Outputs.features)
self.assertIsInstance(features, AttributeList)
self.assertEqual(len(features), 2)

def test_output_correlations(self):
"""Check correlation table on on output"""
self.send_signal("Data", self.data_cont)
correlations = self.get_output("Correlations")
self.send_signal(self.widget.Inputs.data, self.data_cont)
time.sleep(0.1)
self.process_events()
self.widget.commit()
correlations = self.get_output(self.widget.Outputs.correlations)
self.assertIsInstance(correlations, Table)
self.assertEqual(len(correlations), 6)
self.assertEqual(len(correlations.domain.attributes), 1)
self.assertEqual(len(correlations.domain.metas), 2)

def test_scatterplot_input_features(self):
"""Check if attributes have been set after sent to scatterplot"""
self.send_signal("Data", self.data_cont)
self.send_signal(self.widget.Inputs.data, self.data_cont)
scatterplot_widget = self.create_widget(OWScatterPlot)
features = self.get_output("Features")
self.send_signal("Data", self.data_cont, widget=scatterplot_widget)
self.send_signal("Features", features, widget=scatterplot_widget)
attrs = self.widget.cont_data.domain.attributes
self.widget._vizrank_selection_changed(attrs[2], attrs[3])
features = self.get_output(self.widget.Outputs.features)
self.send_signal(self.widget.Inputs.data, self.data_cont, widget=scatterplot_widget)
self.send_signal(scatterplot_widget.Inputs.features, features, widget=scatterplot_widget)
self.assertIs(scatterplot_widget.attr_x, self.data_cont.domain[2])
self.assertIs(scatterplot_widget.attr_y, self.data_cont.domain[3])

def test_heuristic(self):
"""Check attribute pairs got by heuristic"""
heuristic = KMeansCorrelationHeuristic(self.data_cont)
heuristic.n_clusters = 2
self.assertListEqual(list(heuristic.get_states(None)),
[(0, 2), (0, 3), (2, 3)])

def test_heuristic_get_states(self):
"""Check attribute pairs after the widget has been paused"""
heuristic = KMeansCorrelationHeuristic(self.data_cont)
heuristic.n_clusters = 2
states = heuristic.get_states(None)
_ = next(states)
self.assertListEqual(list(heuristic.get_states(next(states))),
[(0, 3), (2, 3)])