Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Impute: Allow setting a default value for all numeric and time variables #5102

Merged
merged 5 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions Orange/data/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,27 +270,6 @@ def read(self):
raise IOError("Couldn't load spreadsheet from " + self.filename)
return table

@classmethod
def write_file(cls, filename, data):
    """Write `data` to `filename` as a single-sheet workbook using xlsxwriter.

    Row 0 receives the names returned by ``cls.header_names(data)``. Each
    following row receives one zipped row of ``data.W``, ``data.X``,
    ``data.Y`` and ``data.metas``, flattened, with every cell passed through
    the formatter that ``cls.formatter`` returns for the matching variable.
    The variable list is: a ``_w`` weight variable (only when
    ``data.has_weights()``), then attributes, class variables and metas.
    """
    columns = list(chain(
        (ContinuousVariable('_w'),) if data.has_weights() else (),
        data.domain.attributes,
        data.domain.class_vars,
        data.domain.metas))
    formatters = [cls.formatter(var) for var in columns]
    # W and Y may be 1-D; add a trailing axis so zip yields one row per item.
    weights = data.W if data.W.ndim > 1 else data.W[:, np.newaxis]
    classes = data.Y if data.Y.ndim > 1 else data.Y[:, np.newaxis]
    rows = zip(weights, data.X, classes, data.metas)
    headers = cls.header_names(data)
    # The context manager closes the workbook even when a formatter or a
    # cell write raises, so the handle is not left dangling.
    with xlsxwriter.Workbook(filename) as workbook:
        sheet = workbook.add_worksheet()
        for col, header in enumerate(headers):
            sheet.write(0, col, header)
        for row_idx, row in enumerate(rows, 1):
            cells = zip(formatters, flatten(row))
            for col, (fmt, value) in enumerate(cells):
                sheet.write(row_idx, col, fmt(value))


class ExcelReader(_BaseExcelReader):
"""Reader for .xlsx files"""
Expand Down Expand Up @@ -332,6 +311,27 @@ def _get_active_sheet(self) -> openpyxl.worksheet.worksheet.Worksheet:
else:
return self.workbook.active

@classmethod
def write_file(cls, filename, data):
    """Write `data` to `filename` as a single-sheet workbook using xlsxwriter.

    Row 0 receives the names returned by ``cls.header_names(data)``. Each
    following row receives one zipped row of ``data.W``, ``data.X``,
    ``data.Y`` and ``data.metas``, flattened, with every cell passed through
    the formatter that ``cls.formatter`` returns for the matching variable.
    The variable list is: a ``_w`` weight variable (only when
    ``data.has_weights()``), then attributes, class variables and metas.
    """
    columns = list(chain(
        (ContinuousVariable('_w'),) if data.has_weights() else (),
        data.domain.attributes,
        data.domain.class_vars,
        data.domain.metas))
    formatters = [cls.formatter(var) for var in columns]
    # W and Y may be 1-D; add a trailing axis so zip yields one row per item.
    weights = data.W if data.W.ndim > 1 else data.W[:, np.newaxis]
    classes = data.Y if data.Y.ndim > 1 else data.Y[:, np.newaxis]
    rows = zip(weights, data.X, classes, data.metas)
    headers = cls.header_names(data)
    # The context manager closes the workbook even when a formatter or a
    # cell write raises, so the handle is not left dangling.
    with xlsxwriter.Workbook(filename) as workbook:
        sheet = workbook.add_worksheet()
        for col, header in enumerate(headers):
            sheet.write(0, col, header)
        for row_idx, row in enumerate(rows, 1):
            cells = zip(formatters, flatten(row))
            for col, (fmt, value) in enumerate(cells):
                sheet.write(row_idx, col, fmt(value))


class XlsReader(_BaseExcelReader):
"""Reader for .xls files"""
Expand Down
41 changes: 39 additions & 2 deletions Orange/preprocess/impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .transformation import Transformation, Lookup

__all__ = ["ReplaceUnknowns", "Average", "DoNotImpute", "DropInstances",
"Model", "AsValue", "Random", "Default"]
"Model", "AsValue", "Random", "Default", "FixedValueByType"]


class ReplaceUnknowns(Transformation):
Expand Down Expand Up @@ -113,6 +113,10 @@ def __call__(self, data, variable, value=None):
a.to_sql = ImputeSql(variable, value)
return a

# Return whatever `variable.is_primitive()` reports.
# NOTE(review): presumably callers use this to decide whether the enclosing
# imputation method can be offered for `variable` — confirm against callers.
@staticmethod
def supports_variable(variable):
return variable.is_primitive()


class ImputeSql(Reprable):
def __init__(self, var, default):
Expand All @@ -124,7 +128,7 @@ def __call__(self):


class Default(BaseImputeMethod):
name = "Value"
name = "Fixed value"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new class was introduced below with name = "Fixed value". Was this one intentionally changed to the same name too (short names remain different)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It's more informative, looks better in the widget.

short_name = "value"
description = ""
columns_only = True
Expand All @@ -142,6 +146,32 @@ def copy(self):
return Default(self.default)


class FixedValueByType(BaseImputeMethod):
    """Impute unknown values with a fixed value chosen by variable type.

    One default is kept per variable class (discrete, continuous, string and
    time). Calling the method returns a copy of the variable whose
    ``compute_value`` is a `ReplaceUnknowns` carrying the selected default.
    """
    name = "Fixed value"
    short_name = "Fixed Value"
    format = "{var.name}"

    def __init__(self,
                 default_discrete=np.nan, default_continuous=np.nan,
                 default_string=None, default_time=np.nan):
        """Store the per-type defaults, keyed by the variable class."""
        self.defaults = {
            Orange.data.DiscreteVariable: default_discrete,
            Orange.data.ContinuousVariable: default_continuous,
            Orange.data.StringVariable: default_string,
            Orange.data.TimeVariable: default_time
        }

    def __call__(self, data, variable, *, default=None):
        """Return a copy of `variable` that replaces unknowns with a default.

        `variable` is resolved through ``data.domain[variable]``. When
        `default` is None, the value registered for the variable's type is
        used; otherwise the explicitly given `default` wins.

        Raises:
            KeyError: if neither the variable's class nor any of its base
                classes has a registered default.
        """
        variable = data.domain[variable]
        if default is None:
            default = self._default_for(type(variable))
        return variable.copy(compute_value=ReplaceUnknowns(variable, default))

    def _default_for(self, var_type):
        """Return the default registered for `var_type` or its nearest base."""
        # Walk the MRO (most specific class first) so that subclasses of the
        # registered variable classes resolve instead of raising KeyError.
        # Membership is tested explicitly because a stored default may be None.
        for klass in var_type.__mro__:
            if klass in self.defaults:
                return self.defaults[klass]
        raise KeyError(var_type)

    def copy(self):
        """Return a new `FixedValueByType` with the same defaults."""
        # Pass the values by keyword so the copy does not depend on the
        # dict's insertion order matching the constructor's argument order.
        return FixedValueByType(
            default_discrete=self.defaults[Orange.data.DiscreteVariable],
            default_continuous=self.defaults[Orange.data.ContinuousVariable],
            default_string=self.defaults[Orange.data.StringVariable],
            default_time=self.defaults[Orange.data.TimeVariable])


class ReplaceUnknownsModel(Reprable):
"""
Replace unknown values with predicted values using a `Orange.base.Model`
Expand Down Expand Up @@ -272,6 +302,9 @@ def __call__(self, data, variable):
else:
raise TypeError(type(variable))

# Return whatever `variable.is_primitive()` reports.
# NOTE(review): presumably callers use this to decide whether the enclosing
# imputation method can be offered for `variable` — confirm against callers.
@staticmethod
def supports_variable(variable):
return variable.is_primitive()

class ReplaceUnknownsRandom(Transformation):
"""
Expand Down Expand Up @@ -354,3 +387,7 @@ def __call__(self, data, variable):
dist[1, :] += 1 / dist.shape[1]
return variable.copy(
compute_value=ReplaceUnknownsRandom(variable, dist))

# Return whatever `variable.is_primitive()` reports.
# NOTE(review): presumably callers use this to decide whether the enclosing
# imputation method can be offered for `variable` — confirm against callers.
@staticmethod
def supports_variable(variable):
return variable.is_primitive()
63 changes: 61 additions & 2 deletions Orange/preprocess/tests/test_impute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import unittest

from Orange.data import DiscreteVariable, ContinuousVariable
from Orange.preprocess.impute import ReplaceUnknownsRandom, ReplaceUnknowns
import numpy as np

from Orange.data import \
Domain, Table, \
DiscreteVariable, ContinuousVariable, TimeVariable, StringVariable
from Orange.preprocess.impute import ReplaceUnknownsRandom, ReplaceUnknowns, \
FixedValueByType
from Orange.statistics.distribution import Discrete


Expand Down Expand Up @@ -52,5 +57,59 @@ def test_equality(self):
self.assertNotEqual(hash(t1), hash(t1a))


class TestFixedValuesByType(unittest.TestCase):
    """Checks `FixedValueByType` on a two-row table with one variable per type."""

    def setUp(self):
        """Build a table with discrete, continuous and time attributes and a string meta."""
        attributes = [
            DiscreteVariable("d", values=tuple("abc")),
            ContinuousVariable("c"),
            TimeVariable("t"),
        ]
        metas = [StringVariable("s")]
        domain = Domain(attributes, [], metas)

        nan = np.nan
        x = np.array([[1, nan, 15], [nan, 42, nan]])
        y = np.empty((2, 0))
        meta_values = np.array([["foo"], [""]])
        self.data = Table(domain, x, y, meta_values)

    def test_none_defined(self):
        """Without arguments, numeric types default to NaN and strings to None."""
        disc_var, cont_var, time_var = self.data.domain.attributes
        str_var, = self.data.domain.metas
        imputer = FixedValueByType()

        for numeric_var in (disc_var, cont_var, time_var):
            imputed = imputer(self.data, numeric_var)
            self.assertIsInstance(imputed.compute_value, ReplaceUnknowns)
            self.assertTrue(np.isnan(imputed.compute_value.value))

        imputed = imputer(self.data, str_var)
        self.assertIsInstance(imputed.compute_value, ReplaceUnknowns)
        self.assertIsNone(imputed.compute_value.value)

    def test_all_defined(self):
        """Each variable type picks up the default given for that type."""
        disc_var, cont_var, time_var = self.data.domain.attributes
        str_var, = self.data.domain.metas
        imputer = FixedValueByType(
            default_discrete=1, default_continuous=42,
            default_string="foo", default_time=3.14)

        expected_by_var = (
            (disc_var, 1),
            (cont_var, 42),
            (time_var, 3.14),
            (str_var, "foo"),
        )
        for var, expected in expected_by_var:
            self.assertEqual(
                imputer(self.data, var).compute_value.value, expected)

    def test_with_default(self):
        """An explicit `default` argument overrides the per-type default."""
        str_var, = self.data.domain.metas
        imputer = FixedValueByType(
            default_discrete=1, default_continuous=42,
            default_string="foo", default_time=3.14)

        imputed = imputer(self.data, str_var, default="bar")
        self.assertEqual(imputed.compute_value.value, "bar")


# Run this module's tests with unittest when it is executed as a script.
if __name__ == "__main__":
unittest.main()
18 changes: 9 additions & 9 deletions Orange/tests/test_xlsx_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def get_dataset(name):


# Build an `io.ExcelReader` for the dataset path returned by `get_dataset`.
# NOTE(review): the +/- diff markers are missing in this rendering. The first
# `return` looks like the removed line and the second like its replacement,
# which appends the ".xlsx" extension to `name` — confirm against the PR.
def get_xlsx_reader(name: str) -> io.ExcelReader:
return io.ExcelReader(get_dataset(name))
return io.ExcelReader(get_dataset(name + ".xlsx"))


# Build an `io.XlsReader` for the dataset path returned by `get_dataset`.
# NOTE(review): the +/- diff markers are missing in this rendering. The first
# `return` looks like the removed line and the second like its replacement,
# which appends the ".xls" extension to `name` — confirm against the PR.
def get_xls_reader(name: str) -> io.XlsReader:
return io.XlsReader(get_dataset(name))
return io.XlsReader(get_dataset(name + ".xls"))


def read_file(reader: Callable, name: str) -> Table:
Expand All @@ -37,7 +37,7 @@ def wrapper(self):

class TestExcelReader(unittest.TestCase):
def test_read_round_floats(self):
table = read_file(get_xlsx_reader, "round_floats.xlsx")
table = read_file(get_xlsx_reader, "round_floats")
domain = table.domain
self.assertIsNone(domain.class_var)
self.assertEqual(len(domain.metas), 0)
Expand All @@ -50,7 +50,7 @@ def test_read_round_floats(self):
class TestExcelHeader0(unittest.TestCase):
@test_xlsx_xls
def test_read(self, reader: Callable[[str], io.FileFormat]):
table = read_file(reader, "header_0.xlsx")
table = read_file(reader, "header_0")
domain = table.domain
self.assertIsNone(domain.class_var)
self.assertEqual(len(domain.metas), 0)
Expand All @@ -68,13 +68,13 @@ def test_read(self, reader: Callable[[str], io.FileFormat]):
class TextExcelSheets(unittest.TestCase):
@test_xlsx_xls
def test_sheets(self, reader: Callable[[str], io.FileFormat]):
reader = reader("header_0_sheet.xlsx")
reader = reader("header_0_sheet")
self.assertSequenceEqual(reader.sheets,
["Sheet1", "my_sheet", "Sheet3"])

@test_xlsx_xls
def test_named_sheet(self, reader: Callable[[str], io.FileFormat]):
reader = reader("header_0_sheet.xlsx")
reader = reader("header_0_sheet")
reader.select_sheet("my_sheet")
table = reader.read()
self.assertEqual(len(table.domain.attributes), 4)
Expand All @@ -96,7 +96,7 @@ def test_named_sheet_table_xls(self):
class TestExcelHeader1(unittest.TestCase):
@test_xlsx_xls
def test_no_flags(self, reader: Callable[[str], io.FileFormat]):
table = read_file(reader, "header_1_no_flags.xlsx")
table = read_file(reader, "header_1_no_flags")
domain = table.domain
self.assertEqual(len(domain.metas), 0)
self.assertEqual(len(domain.attributes), 4)
Expand All @@ -115,7 +115,7 @@ def test_no_flags(self, reader: Callable[[str], io.FileFormat]):

@test_xlsx_xls
def test_flags(self, reader: Callable[[str], io.FileFormat]):
table = read_file(reader, "header_1_flags.xlsx")
table = read_file(reader, "header_1_flags")
domain = table.domain

self.assertEqual(len(domain.attributes), 1)
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_flags(self, reader: Callable[[str], io.FileFormat]):
class TestExcelHeader3(unittest.TestCase):
@test_xlsx_xls
def test_read(self, reader: Callable[[str], io.FileFormat]):
table = read_file(reader, "header_3.xlsx")
table = read_file(reader, "header_3")
domain = table.domain

self.assertEqual(len(domain.attributes), 2)
Expand Down
Binary file added Orange/tests/xlsx_files/header_0.xls
Binary file not shown.
Binary file added Orange/tests/xlsx_files/header_1_flags.xls
Binary file not shown.
Binary file added Orange/tests/xlsx_files/header_1_no_flags.xls
Binary file not shown.
Binary file added Orange/tests/xlsx_files/header_3.xls
Binary file not shown.
Loading