Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Scatter plot: Bring discrete attributes functionality back #5440

Merged
merged 1 commit into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from xml.sax.saxutils import escape

import numpy as np
from AnyQt.QtWidgets import QGroupBox, QPushButton
from scipy.stats import linregress
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import r2_score
Expand All @@ -11,15 +12,17 @@

import pyqtgraph as pg

from Orange.data import Table, Domain, DiscreteVariable, Variable, \
ContinuousVariable
from orangewidget.utils.combobox import ComboBoxSearch

from Orange.data import Table, Domain, DiscreteVariable, Variable
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
from Orange.preprocess.score import ReliefF, RReliefF

from Orange.widgets import gui
from Orange.widgets import gui, report
from Orange.widgets.io import MatplotlibFormat, MatplotlibPDFFormat
from Orange.widgets.settings import (
Setting, ContextSetting, SettingProvider, IncompatibleContext)
from Orange.widgets.utils import get_variable_values_sorted
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase, \
Expand Down Expand Up @@ -85,7 +88,7 @@ def score_heuristic(self):
assert self.attr_color is not None
master_domain = self.master.data.domain
vars = [v for v in chain(master_domain.variables, master_domain.metas)
if v is not self.attr_color and v.is_continuous]
if v is not self.attr_color and v.is_primitive()]
domain = Domain(attributes=vars, class_vars=self.attr_color)
data = self.master.data.transform(domain)
relief = ReliefF if isinstance(domain.class_var, DiscreteVariable) \
Expand Down Expand Up @@ -117,6 +120,7 @@ def axis_items(self):
class OWScatterPlotGraph(OWScatterPlotBase):
show_reg_line = Setting(False)
orthonormal_regression = Setting(False)
jitter_continuous = Setting(False)

def __init__(self, scatter_widget, parent):
super().__init__(scatter_widget, parent)
Expand All @@ -137,12 +141,34 @@ def update_colors(self):
super().update_colors()
self.update_regression_line()

def jitter_coordinates(self, x, y):
def get_span(attr):
if attr.is_discrete:
# Assuming the maximal jitter size is 10, a span of 4 will
# jitter by 4 * 10 / 100 = 0.4, so there will be no overlap
return 4
elif self.jitter_continuous:
return None # Let _jitter_data determine the span
else:
return 0 # No jittering
span_x = get_span(self.master.attr_x)
span_y = get_span(self.master.attr_y)
if self.jitter_size == 0 or (span_x == 0 and span_y == 0):
return x, y
return self._jitter_data(x, y, span_x, span_y)

def update_axes(self):
for axis, title in self.master.get_axes().items():
use_time = title is not None and title.is_time
for axis, var in self.master.get_axes().items():
axis_item = self.plot_widget.plotItem.getAxis(axis)
if var and var.is_discrete:
ticks = [list(enumerate(get_variable_values_sorted(var)))]
axis_item.setTicks(ticks)
else:
axis_item.setTicks(None)
use_time = var and var.is_time
self.plot_widget.plotItem.getAxis(axis).use_time(use_time)
self.plot_widget.setLabel(axis=axis, text=title or "")
if title is None:
self.plot_widget.setLabel(axis=axis, text=var or "")
if not var:
self.plot_widget.hideAxis(axis)

@staticmethod
Expand Down Expand Up @@ -203,7 +229,8 @@ def update_regression_line(self):
for line in self.reg_line_items:
self.plot_widget.removeItem(line)
self.reg_line_items.clear()
if not self.show_reg_line:
if not (self.show_reg_line
and self.master.can_draw_regresssion_line()):
return
x, y = self.master.get_coordinates_data()
if x is None:
Expand Down Expand Up @@ -238,7 +265,7 @@ class Inputs(OWDataProjectionWidget.Inputs):
class Outputs(OWDataProjectionWidget.Outputs):
features = Output("Features", AttributeList, dynamic=False)

settings_version = 4
settings_version = 5
auto_sample = Setting(True)
attr_x = ContextSetting(None)
attr_y = ContextSetting(None)
Expand All @@ -254,14 +281,21 @@ class Warning(OWDataProjectionWidget.Warning):
missing_coords = Msg(
"Plot cannot be displayed because '{}' or '{}' "
"is missing for all data points.")
no_continuous_vars = Msg("Data has no numeric variables.")

class Information(OWDataProjectionWidget.Information):
sampled_sql = Msg("Large SQL table; showing a sample.")
missing_coords = Msg(
"Points with missing '{}' or '{}' are not displayed")

def __init__(self):
self.attr_box: QGroupBox = None
self.xy_model: DomainModel = None
self.cb_attr_x: ComboBoxSearch = None
self.cb_attr_y: ComboBoxSearch = None
self.vizrank: ScatterPlotVizRank = None
self.vizrank_button: QPushButton = None
self.sampling: QGroupBox = None

self.sql_data = None # Orange.data.sql.table.SqlTable
self.attribute_selection_list = None # list of Orange.data.Variable
self.__timer = QTimer(self, interval=1200)
Expand All @@ -277,6 +311,7 @@ def _add_controls(self):
self._add_controls_axis()
self._add_controls_sampling()
super()._add_controls()
self.gui.add_widget(self.gui.JitterNumericValues, self._effects_box)
self.gui.add_widgets(
[self.gui.ShowGridLines,
self.gui.ToolTipShowsAll,
Expand All @@ -300,7 +335,7 @@ def _add_controls_axis(self):
self.attr_box = gui.vBox(self.controlArea, 'Axes',
spacing=2 if gui.is_macstyle() else 8)
dmod = DomainModel
self.xy_model = DomainModel(dmod.MIXED, valid_types=ContinuousVariable)
self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
self.cb_attr_x = gui.comboBox(
self.attr_box, self, "attr_x", label="Axis x:",
callback=self.set_attr_from_combo,
Expand Down Expand Up @@ -393,11 +428,6 @@ def check_data(self):
if self.auto_sample:
self.__timer.start()

if self.data is not None:
if not self.data.domain.has_continuous_attributes(True, True):
self.Warning.no_continuous_vars()
self.data = None

if self.data is not None and (len(self.data) == 0 or
len(self.data.domain.variables) == 0):
self.data = None
Expand Down Expand Up @@ -433,6 +463,12 @@ def _point_tooltip(self, point_id, skip_attrs=()):
text = "<b>{}</b><br/><br/>{}".format(text, others)
return text

def can_draw_regresssion_line(self):
return self.data is not None and \
self.data.domain is not None and \
self.attr_x.is_continuous and \
self.attr_y.is_continuous

def add_data(self, time=0.4):
if self.data and len(self.data) > 2000:
self.__timer.stop()
Expand Down Expand Up @@ -484,6 +520,7 @@ def handleNewSignals(self):
if self._domain_invalidated:
self.graph.update_axes()
self._domain_invalidated = False
self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())

@Inputs.features
def set_shown_attributes(self, attributes):
Expand All @@ -505,6 +542,7 @@ def set_attr_from_combo(self):
self.xy_changed_manually.emit(self.attr_x, self.attr_y)

def attr_changed(self):
self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())
self.setup_plot()
self.commit()

Expand All @@ -528,6 +566,17 @@ def get_widget_name_extension(self):
return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
return None

def _get_send_report_caption(self):
return report.render_items_vert((
("Color", self._get_caption_var_name(self.attr_color)),
("Label", self._get_caption_var_name(self.attr_label)),
("Shape", self._get_caption_var_name(self.attr_shape)),
("Size", self._get_caption_var_name(self.attr_size)),
("Jittering", (self.attr_x.is_discrete or
self.attr_y.is_discrete or
self.graph.jitter_continuous) and
self.graph.jitter_size)))

@classmethod
def migrate_settings(cls, settings, version):
if version < 2 and "selection" in settings and settings["selection"]:
Expand All @@ -537,6 +586,10 @@ def migrate_settings(cls, settings, version):
settings["auto_commit"] = settings["auto_send_selection"]
if "selection_group" in settings:
settings["selection"] = settings["selection_group"]
if version < 5:
VesnaT marked this conversation as resolved.
Show resolved Hide resolved
if "graph" in settings and \
"jitter_continuous" not in settings["graph"]:
settings["graph"]["jitter_continuous"] = True

@classmethod
def migrate_context(cls, context, version):
Expand Down
Loading