Skip to content

Commit

Permalink
Merge pull request #6114 from markotoplak/compute-value-warn
Browse files Browse the repository at this point in the history
Extend warnings for missing __eq__ within Variables
  • Loading branch information
janezd authored Sep 1, 2022
2 parents b0485ba + bdaee7d commit fd70c60
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 3 deletions.
36 changes: 35 additions & 1 deletion Orange/data/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from Orange.data import Domain, ContinuousVariable
from Orange.data.util import get_unique_names, get_unique_names_duplicates, \
get_unique_names_domain, one_hot, sanitized_name
get_unique_names_domain, one_hot, sanitized_name, redefines_eq_and_hash


class TestGetUniqueNames(unittest.TestCase):
Expand Down Expand Up @@ -309,5 +309,39 @@ def test_sanitized_name(self):
self.assertEqual(sanitized_name("1 Foo Bar"), "_1_Foo_Bar")


class TestRedefinesEqAndHash(unittest.TestCase):

class Valid:
def __eq__(self, other):
pass

def __hash__(self):
pass

class Subclass(Valid):
pass

class OnlyEq:
def __eq__(self, other):
pass

class OnlyHash:
def __hash__(self):
pass

def test_valid(self):
self.assertTrue(redefines_eq_and_hash(self.Valid))
self.assertTrue(redefines_eq_and_hash(self.Valid()))

def test_subclass(self):
self.assertFalse(redefines_eq_and_hash(self.Subclass))

def test_only_eq(self):
self.assertFalse(redefines_eq_and_hash(self.OnlyEq))

def test_only_hash(self):
self.assertFalse(redefines_eq_and_hash(self.OnlyHash))


if __name__ == "__main__":
unittest.main()
18 changes: 18 additions & 0 deletions Orange/data/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,24 @@ class Invalid:
ContinuousVariable("x", compute_value=Invalid())
self.assertNotEqual(warns, [])

with warnings.catch_warnings(record=True) as warns:

class MissingHash:
def __eq__(self, other):
return self is other

ContinuousVariable("x", compute_value=MissingHash())
self.assertNotEqual(warns, [])

with warnings.catch_warnings(record=True) as warns:

class MissingEq:
def __hash__(self):
return super().__hash__(self)

ContinuousVariable("x", compute_value=MissingEq())
self.assertNotEqual(warns, [])


def variabletest(varcls):
def decorate(cls):
Expand Down
33 changes: 33 additions & 0 deletions Orange/data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Data-manipulation utilities.
"""
import re
import types
import warnings
from collections import Counter
from itertools import chain, count
from typing import Callable, Union, List, Type
Expand Down Expand Up @@ -72,6 +74,12 @@ class SharedComputeValue:
def __init__(self, compute_shared, variable=None):
self.compute_shared = compute_shared
self.variable = variable
if compute_shared is not None \
and not isinstance(compute_shared, (types.BuiltinFunctionType,
types.FunctionType)) \
and not redefines_eq_and_hash(compute_shared):
warnings.warn(f"{type(compute_shared).__name__} should define"
f"__eq__ and __hash__ to be used for compute_shared")

def __call__(self, data, shared_data=None):
"""Fallback if common parts are not passed."""
Expand All @@ -85,6 +93,14 @@ def compute(self, data, shared_data):
Subclasses need to implement this function."""
raise NotImplementedError

def __eq__(self, other):
return type(self) is type(other) \
and self.compute_shared == other.compute_shared \
and self.variable == other.variable

def __hash__(self):
return hash((type(self), self.compute_shared, self.variable))


def vstack(arrays):
"""vstack that supports sparse and dense arrays
Expand Down Expand Up @@ -307,3 +323,20 @@ def sanitized_name(name: str) -> str:
if sanitized[0].isdigit():
sanitized = "_" + sanitized
return sanitized


def redefines_eq_and_hash(this):
"""
Check if the passed object (or class) redefines __eq__ and __hash__.
Args:
this: class or object
"""
if not isinstance(this, type):
this = type(this)

# if only __eq__ is defined, __hash__ is set to None
if this.__hash__ is None:
return False

return "__hash__" in this.__dict__ and "__eq__" in this.__dict__
5 changes: 3 additions & 2 deletions Orange/data/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import scipy.sparse as sp

from Orange.data import _variable
from Orange.data.util import redefines_eq_and_hash
from Orange.util import Registry, Reprable, OrangeDeprecationWarning


Expand Down Expand Up @@ -368,11 +369,11 @@ def __init__(self, name="", compute_value=None, *, sparse=False):
warnings.warn("Variable must have a name", OrangeDeprecationWarning,
stacklevel=3)
self._name = name

if compute_value is not None \
and not isinstance(compute_value, (types.BuiltinFunctionType,
types.FunctionType)) \
and (type(compute_value).__eq__ is object.__eq__
or compute_value.__hash__ is object.__hash__):
and not redefines_eq_and_hash(compute_value):
warnings.warn(f"{type(compute_value).__name__} should define"
f"__eq__ and __hash__ to be used for compute_value")

Expand Down
47 changes: 47 additions & 0 deletions Orange/tests/test_data_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import warnings
from unittest.mock import Mock

import numpy as np
Expand Down Expand Up @@ -72,3 +73,49 @@ def test_single_call(self):
#test with descendants of table
DummyTable.from_table(c.domain, data)
self.assertEqual(obj.compute_shared.call_count, 4)

def test_compute_shared_eq_warning(self):
with warnings.catch_warnings(record=True) as warns:
DummyPlus(compute_shared=lambda *_: 42)

class Valid:
def __eq__(self, other):
pass

def __hash__(self):
pass

DummyPlus(compute_shared=Valid())
self.assertEqual(warns, [])

class Invalid:
pass

DummyPlus(compute_shared=Invalid())
self.assertNotEqual(warns, [])

with warnings.catch_warnings(record=True) as warns:

class MissingHash:
def __eq__(self, other):
pass

DummyPlus(compute_shared=MissingHash())
self.assertNotEqual(warns, [])

with warnings.catch_warnings(record=True) as warns:

class MissingEq:
def __hash__(self):
pass

DummyPlus(compute_shared=MissingEq())
self.assertNotEqual(warns, [])

with warnings.catch_warnings(record=True) as warns:

class Subclass(Valid):
pass

DummyPlus(compute_shared=Subclass())
self.assertNotEqual(warns, [])

0 comments on commit fd70c60

Please sign in to comment.