Skip to content

Commit

Permalink
Merge pull request #5440 from VesnaT/scatter_discrete
Browse files Browse the repository at this point in the history
[ENH] Scatter plot: Bring discrete attributes functionality back
  • Loading branch information
janezd authored May 21, 2021
2 parents 57ca4c7 + 89ed01b commit 96fda39
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 86 deletions.
87 changes: 70 additions & 17 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from xml.sax.saxutils import escape

import numpy as np
from AnyQt.QtWidgets import QGroupBox, QPushButton
from scipy.stats import linregress
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import r2_score
Expand All @@ -11,15 +12,17 @@

import pyqtgraph as pg

from Orange.data import Table, Domain, DiscreteVariable, Variable, \
ContinuousVariable
from orangewidget.utils.combobox import ComboBoxSearch

from Orange.data import Table, Domain, DiscreteVariable, Variable
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
from Orange.preprocess.score import ReliefF, RReliefF

from Orange.widgets import gui
from Orange.widgets import gui, report
from Orange.widgets.io import MatplotlibFormat, MatplotlibPDFFormat
from Orange.widgets.settings import (
Setting, ContextSetting, SettingProvider, IncompatibleContext)
from Orange.widgets.utils import get_variable_values_sorted
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase, \
Expand Down Expand Up @@ -85,7 +88,7 @@ def score_heuristic(self):
assert self.attr_color is not None
master_domain = self.master.data.domain
vars = [v for v in chain(master_domain.variables, master_domain.metas)
if v is not self.attr_color and v.is_continuous]
if v is not self.attr_color and v.is_primitive()]
domain = Domain(attributes=vars, class_vars=self.attr_color)
data = self.master.data.transform(domain)
relief = ReliefF if isinstance(domain.class_var, DiscreteVariable) \
Expand Down Expand Up @@ -117,6 +120,7 @@ def axis_items(self):
class OWScatterPlotGraph(OWScatterPlotBase):
show_reg_line = Setting(False)
orthonormal_regression = Setting(False)
jitter_continuous = Setting(False)

def __init__(self, scatter_widget, parent):
super().__init__(scatter_widget, parent)
Expand All @@ -137,12 +141,34 @@ def update_colors(self):
super().update_colors()
self.update_regression_line()

def jitter_coordinates(self, x, y):
    """Jitter the given coordinates according to the axis variable types.

    A span is chosen separately for each axis from the variables in
    ``self.master.attr_x`` and ``self.master.attr_y``; both spans are then
    handed to ``self._jitter_data``.

    Args:
        x: coordinates along the x axis.
        y: coordinates along the y axis.

    Returns:
        ``(x, y)`` unchanged when ``self.jitter_size`` is 0 or when both
        spans are 0; otherwise whatever
        ``self._jitter_data(x, y, span_x, span_y)`` returns.
    """
    def span_for(variable):
        if variable.is_discrete:
            # Assuming the maximal jitter size is 10, a span of 4 will
            # jitter by 4 * 10 / 100 = 0.4, so there will be no overlap
            return 4
        # Non-discrete axis: None leaves the choice of span to
        # _jitter_data, 0 switches jittering off for this axis.
        return None if self.jitter_continuous else 0

    horizontal_span = span_for(self.master.attr_x)
    vertical_span = span_for(self.master.attr_y)

    neither_axis_jittered = horizontal_span == 0 and vertical_span == 0
    if self.jitter_size == 0 or neither_axis_jittered:
        return x, y
    return self._jitter_data(x, y, horizontal_span, vertical_span)

def update_axes(self):
for axis, title in self.master.get_axes().items():
use_time = title is not None and title.is_time
for axis, var in self.master.get_axes().items():
axis_item = self.plot_widget.plotItem.getAxis(axis)
if var and var.is_discrete:
ticks = [list(enumerate(get_variable_values_sorted(var)))]
axis_item.setTicks(ticks)
else:
axis_item.setTicks(None)
use_time = var and var.is_time
self.plot_widget.plotItem.getAxis(axis).use_time(use_time)
self.plot_widget.setLabel(axis=axis, text=title or "")
if title is None:
self.plot_widget.setLabel(axis=axis, text=var or "")
if not var:
self.plot_widget.hideAxis(axis)

@staticmethod
Expand Down Expand Up @@ -203,7 +229,8 @@ def update_regression_line(self):
for line in self.reg_line_items:
self.plot_widget.removeItem(line)
self.reg_line_items.clear()
if not self.show_reg_line:
if not (self.show_reg_line
and self.master.can_draw_regresssion_line()):
return
x, y = self.master.get_coordinates_data()
if x is None:
Expand Down Expand Up @@ -238,7 +265,7 @@ class Inputs(OWDataProjectionWidget.Inputs):
class Outputs(OWDataProjectionWidget.Outputs):
features = Output("Features", AttributeList, dynamic=False)

settings_version = 4
settings_version = 5
auto_sample = Setting(True)
attr_x = ContextSetting(None)
attr_y = ContextSetting(None)
Expand All @@ -254,14 +281,21 @@ class Warning(OWDataProjectionWidget.Warning):
missing_coords = Msg(
"Plot cannot be displayed because '{}' or '{}' "
"is missing for all data points.")
no_continuous_vars = Msg("Data has no numeric variables.")

class Information(OWDataProjectionWidget.Information):
sampled_sql = Msg("Large SQL table; showing a sample.")
missing_coords = Msg(
"Points with missing '{}' or '{}' are not displayed")

def __init__(self):
self.attr_box: QGroupBox = None
self.xy_model: DomainModel = None
self.cb_attr_x: ComboBoxSearch = None
self.cb_attr_y: ComboBoxSearch = None
self.vizrank: ScatterPlotVizRank = None
self.vizrank_button: QPushButton = None
self.sampling: QGroupBox = None

self.sql_data = None # Orange.data.sql.table.SqlTable
self.attribute_selection_list = None # list of Orange.data.Variable
self.__timer = QTimer(self, interval=1200)
Expand All @@ -277,6 +311,7 @@ def _add_controls(self):
self._add_controls_axis()
self._add_controls_sampling()
super()._add_controls()
self.gui.add_widget(self.gui.JitterNumericValues, self._effects_box)
self.gui.add_widgets(
[self.gui.ShowGridLines,
self.gui.ToolTipShowsAll,
Expand All @@ -300,7 +335,7 @@ def _add_controls_axis(self):
self.attr_box = gui.vBox(self.controlArea, 'Axes',
spacing=2 if gui.is_macstyle() else 8)
dmod = DomainModel
self.xy_model = DomainModel(dmod.MIXED, valid_types=ContinuousVariable)
self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
self.cb_attr_x = gui.comboBox(
self.attr_box, self, "attr_x", label="Axis x:",
callback=self.set_attr_from_combo,
Expand Down Expand Up @@ -393,11 +428,6 @@ def check_data(self):
if self.auto_sample:
self.__timer.start()

if self.data is not None:
if not self.data.domain.has_continuous_attributes(True, True):
self.Warning.no_continuous_vars()
self.data = None

if self.data is not None and (len(self.data) == 0 or
len(self.data.domain.variables) == 0):
self.data = None
Expand Down Expand Up @@ -433,6 +463,12 @@ def _point_tooltip(self, point_id, skip_attrs=()):
text = "<b>{}</b><br/><br/>{}".format(text, others)
return text

def can_draw_regresssion_line(self):
    """Tell whether a regression line can be drawn for the current axes.

    The spelling of the method name is kept as is because other code
    calls it under this name.

    Returns:
        A true value only when data with a domain is present and both
        ``attr_x`` and ``attr_y`` are set and continuous; ``False``
        otherwise.
    """
    if self.data is None or self.data.domain is None:
        return False
    attr_x, attr_y = self.attr_x, self.attr_y
    # Either axis may be unset (both default to None); without this guard
    # the ``is_continuous`` lookups below would raise AttributeError.
    if attr_x is None or attr_y is None:
        return False
    return attr_x.is_continuous and attr_y.is_continuous

def add_data(self, time=0.4):
if self.data and len(self.data) > 2000:
self.__timer.stop()
Expand Down Expand Up @@ -484,6 +520,7 @@ def handleNewSignals(self):
if self._domain_invalidated:
self.graph.update_axes()
self._domain_invalidated = False
self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())

@Inputs.features
def set_shown_attributes(self, attributes):
Expand All @@ -505,6 +542,7 @@ def set_attr_from_combo(self):
self.xy_changed_manually.emit(self.attr_x, self.attr_y)

def attr_changed(self):
    """React to a change of an axis attribute.

    Enables or disables the regression-line control according to
    ``can_draw_regresssion_line()``, then rebuilds the plot and commits.
    """
    regression_possible = self.can_draw_regresssion_line()
    self.cb_reg_line.setEnabled(regression_possible)
    self.setup_plot()
    self.commit()

Expand All @@ -528,6 +566,17 @@ def get_widget_name_extension(self):
return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
return None

def _get_send_report_caption(self):
    """Build the report caption from the current visual settings.

    The caption lists the variables used for color, label, shape and
    size (each passed through ``self._get_caption_var_name``), followed
    by a "Jittering" entry. That entry holds ``self.graph.jitter_size``
    when either axis variable is discrete or jittering of continuous
    values is switched on; otherwise it holds the falsy result of that
    condition.

    Returns:
        The value returned by ``report.render_items_vert`` for these
        items.
    """
    describe = self._get_caption_var_name
    entries = (
        ("Color", describe(self.attr_color)),
        ("Label", describe(self.attr_label)),
        ("Shape", describe(self.attr_shape)),
        ("Size", describe(self.attr_size)),
    )

    jitter_applies = (self.attr_x.is_discrete
                      or self.attr_y.is_discrete
                      or self.graph.jitter_continuous)
    jittering = jitter_applies and self.graph.jitter_size

    return report.render_items_vert(entries + (("Jittering", jittering),))

@classmethod
def migrate_settings(cls, settings, version):
if version < 2 and "selection" in settings and settings["selection"]:
Expand All @@ -537,6 +586,10 @@ def migrate_settings(cls, settings, version):
settings["auto_commit"] = settings["auto_send_selection"]
if "selection_group" in settings:
settings["selection"] = settings["selection_group"]
if version < 5:
if "graph" in settings and \
"jitter_continuous" not in settings["graph"]:
settings["graph"]["jitter_continuous"] = True

@classmethod
def migrate_context(cls, context, version):
Expand Down
Loading

0 comments on commit 96fda39

Please sign in to comment.