Skip to content

Commit

Permalink
Merge pull request #253 from janezd/split-polish
Browse files Browse the repository at this point in the history
Split: Refactor for discrete values, add tests, rename
  • Loading branch information
markotoplak authored Nov 22, 2023
2 parents 0d12c42 + a74ea06 commit f08b42f
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 330 deletions.
184 changes: 0 additions & 184 deletions orangecontrib/prototypes/widgets/icons/Split.svg

This file was deleted.

33 changes: 33 additions & 0 deletions orangecontrib/prototypes/widgets/icons/TextToColumns.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import numpy as np

from AnyQt.QtCore import Qt
Expand All @@ -13,63 +15,81 @@
from orangewidget.settings import Setting


def get_substrings(values, delimiter):
return sorted({ss.strip() for s in values for ss in s.split(delimiter)}
- {""})


class SplitColumn:
def __init__(self, data, attr, delimiter):
self.attr = attr
self.delimiter = delimiter
column = set(data.get_column(self.attr))
self.new_values = tuple(get_substrings(column, self.delimiter))

column = self.get_string_values(data, self.attr)
values = [s.split(self.delimiter) for s in column]
self.new_values = tuple(sorted({val if val else "?" for vals in
values for val in vals}))
def __call__(self, data):
column = data.get_column(self.attr)
values = [{ss.strip() for ss in s.split(self.delimiter)}
for s in column]
return {v: np.array([i for i, xs in enumerate(values) if v in xs])
for v in self.new_values}

def __eq__(self, other):
return self.attr == other.attr and self.delimiter == \
other.delimiter and self.new_values == other.new_values
return self.attr == other.attr \
and self.delimiter == other.delimiter \
and self.new_values == other.new_values

def __hash__(self):
return hash((self.attr, self.delimiter, self.new_values))

def __call__(self, data):
column = self.get_string_values(data, self.attr)
values = [set(s.split(self.delimiter)) for s in column]
shared_data = {v: [i for i, xs in enumerate(values) if v in xs] for v
in self.new_values}
return shared_data

@staticmethod
def get_string_values(data, var):
# turn discrete to string variable
column = data.get_column(var)
if var.is_discrete:
return [var.str_val(x) for x in column]
return column


class OneHotStrings(SharedComputeValue):

def __init__(self, fn, new_feature):
super().__init__(fn)
self.new_feature = new_feature

def __eq__(self, other):
return self.compute_shared == other.compute_shared \
and self.new_feature == other.new_feature

def __hash__(self):
return hash((self.compute_shared, self.new_feature))

def compute(self, data, shared_data):
indices = shared_data[self.new_feature]
col = np.zeros(len(data))
col[indices] = 1
return col

def __eq__(self, other):
return super().__eq__(other) and self.new_feature == other.new_feature

class OWSplit(OWWidget):
name = "Split"
description = "Split string variables to create discrete."
icon = "icons/Split.svg"
def __hash__(self):
return super().__hash__() ^ hash(self.new_feature)


class OneHotDiscrete:
def __init__(self, variable, delimiter, value):
self.variable = variable
self.value = value
self.delimiter = delimiter

def __call__(self, data):
column = data.get_column(self.variable).astype(float)
col = np.zeros(len(column))
col[np.isnan(column)] = np.nan
for val_idx, value in enumerate(self.variable.values):
if self.value in value.split(self.delimiter):
col[column == val_idx] = 1
return col

def __eq__(self, other):
return self.variable == other.variable \
and self.value == other.value \
and self.delimiter == other.delimiter

def __hash__(self):
return hash((self.variable, self.value, self.delimiter))


class OWTextToColumns(OWWidget):
name = "Text to Columns"
description = "Split text or categorical variables into binary indicators"
icon = "icons/TextToColumns.svg"
keywords = ["split"]
priority = 700

class Inputs:
Expand Down Expand Up @@ -129,12 +149,18 @@ def apply(self):
return
var = self.data.domain[self.attribute]

sc = SplitColumn(self.data, var, self.delimiter)
if var.is_discrete:
values = get_substrings(var.values, self.delimiter)
computer = partial(OneHotDiscrete, var, self.delimiter)
else:
sc = SplitColumn(self.data, var, self.delimiter)
values = sc.new_values
computer = partial(OneHotStrings, sc)
names = get_unique_names(self.data.domain, values, equal_numbers=False)

new_columns = tuple(DiscreteVariable(
get_unique_names(self.data.domain, v), values=("0", "1"),
compute_value=OneHotStrings(sc, v)
) for v in sc.new_values)
name, values=("0", "1"), compute_value=computer(value)
) for value, name in zip(values, names))

new_domain = Domain(
self.data.domain.attributes + new_columns,
Expand All @@ -145,5 +171,5 @@ def apply(self):


if __name__ == "__main__": # pragma: no cover
WidgetPreview(OWSplit).run(Table.from_file(
WidgetPreview(OWTextToColumns).run(Table.from_file(
"tests/orange-in-education.tab"))
Loading

0 comments on commit f08b42f

Please sign in to comment.