Skip to content

Commit

Permalink
Merge pull request #5428 from VesnaT/dbscan_normalize
Browse files Browse the repository at this point in the history
[ENH] DBSCAN: Optional normalization
  • Loading branch information
markotoplak authored May 7, 2021
2 parents 6f25bda + 5a0f03e commit 4376a88
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
27 changes: 21 additions & 6 deletions Orange/widgets/unsupervised/owdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Error(widget.OWWidget.Error):
min_samples = Setting(4)
eps = Setting(0.5)
metric_idx = Setting(0)
normalize = Setting(True)
auto_commit = Setting(True)
k_distances = None
cut_point = None
Expand All @@ -102,6 +103,8 @@ def __init__(self):
gui.comboBox(box, self, "metric_idx",
items=list(zip(*self.METRICS))[0],
callback=self._metirc_changed)
gui.checkBox(box, self, "normalize", "Normalize features",
callback=self._on_normalize_changed)

gui.auto_apply(self.buttonsArea, self, "auto_commit")
gui.rubber(self.controlArea)
Expand Down Expand Up @@ -161,9 +164,9 @@ def _compute_cut_point(self):
self.cut_point = int(DEFAULT_CUT_POINT * len(self.k_distances))
self.eps = self.k_distances[self.cut_point]

if self.eps < EPS_BOTTOM_LIMIT:
self.eps = np.min(
self.k_distances[self.k_distances >= EPS_BOTTOM_LIMIT])
mask = self.k_distances >= EPS_BOTTOM_LIMIT
if self.eps < EPS_BOTTOM_LIMIT and sum(mask):
self.eps = np.min(self.k_distances[mask])
self.cut_point = self._find_nearest_dist(self.eps)

@Inputs.data
Expand All @@ -180,13 +183,18 @@ def set_data(self, data):
if self.data is None:
return

# preprocess data
for pp in PREPROCESSORS:
self.data_normalized = pp(self.data_normalized)
self._preprocess_data()

self._compute_and_plot()
self.unconditional_commit()

def _preprocess_data(self):
self.data_normalized = self.data
for pp in PREPROCESSORS:
if isinstance(pp, Normalize) and not self.normalize:
continue
self.data_normalized = pp(self.data_normalized)

def send_data(self):
model = self.model

Expand Down Expand Up @@ -248,6 +256,13 @@ def _min_samples_changed(self):
self._compute_and_plot(cut_point=self.cut_point)
self._invalidate()

def _on_normalize_changed(self):
if not self.data:
return
self._preprocess_data()
self._compute_and_plot()
self._invalidate()


if __name__ == "__main__":
a = QApplication(sys.argv)
Expand Down
46 changes: 46 additions & 0 deletions Orange/widgets/unsupervised/tests/test_owdbscan.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# pylint: disable=protected-access
import unittest

import numpy as np
from scipy.sparse import csr_matrix, csc_matrix

from Orange.data import Table
from Orange.clustering import DBSCAN
from Orange.distance import Euclidean
from Orange.preprocess import Normalize, Continuize, SklImpute
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import simulate, possible_duplicate_table
from Orange.widgets.unsupervised.owdbscan import OWDBSCAN, get_kth_distances
Expand Down Expand Up @@ -226,3 +230,45 @@ def test_missing_data(self):
self.send_signal(w.Inputs.data, self.iris)
output = self.get_output(w.Outputs.annotated_data)
self.assertTupleEqual((150, 1), output[:, "Cluster"].metas.shape)

def test_normalize_data(self):
# not normalized
self.widget.controls.normalize.setChecked(False)

data = Table("heart_disease")
self.send_signal(self.widget.Inputs.data, data)

kwargs = {"eps": self.widget.eps,
"min_samples": self.widget.min_samples,
"metric": "euclidean"}
clusters = DBSCAN(**kwargs)(data)

output = self.get_output(self.widget.Outputs.annotated_data)
output_clusters = output.metas[:, 0]
output_clusters[np.isnan(output_clusters)] = -1
np.testing.assert_array_equal(output_clusters, clusters)

# normalized
self.widget.controls.normalize.setChecked(True)

kwargs = {"eps": self.widget.eps,
"min_samples": self.widget.min_samples,
"metric": "euclidean"}
for pp in (Continuize(), Normalize(), SklImpute()):
data = pp(data)
clusters = DBSCAN(**kwargs)(data)

output = self.get_output(self.widget.Outputs.annotated_data)
output_clusters = output.metas[:, 0]
output_clusters[np.isnan(output_clusters)] = -1
np.testing.assert_array_equal(output_clusters, clusters)

def test_normalize_changed(self):
self.send_signal(self.widget.Inputs.data, self.iris)
simulate.combobox_run_through_all(self.widget.controls.metric_idx)
self.widget.controls.normalize.setChecked(False)
simulate.combobox_run_through_all(self.widget.controls.metric_idx)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions Orange/widgets/utils/slidergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def clear_plot(self):
This function clears the plot and removes data.
"""
self.clear()
self.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0))
self.plot_horlabel = []
self.plot_horline = []
self._line = None
Expand Down

0 comments on commit 4376a88

Please sign in to comment.