Skip to content

Commit

Permalink
domain: fix unpickling with circles
Browse files Browse the repository at this point in the history
Unpickling dictionaries whose keys are objects that redefine __hash__
is sometimes problematic (when said objects do not yet have their __dict__
filled in but are already used as keys in a restored dictionary).
  • Loading branch information
markotoplak committed Jan 26, 2023
1 parent d67cc59 commit 03168a2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
38 changes: 30 additions & 8 deletions Orange/data/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,38 @@ def __init__(self, attributes, class_vars=None, metas=None, source=None):
if not all(var.is_primitive() for var in self._variables):
raise TypeError("variables must be primitive")

self._indices = dict(chain.from_iterable(
((var, idx), (var.name, idx), (idx, idx))
for idx, var in enumerate(self._variables)))
self._indices.update(chain.from_iterable(
((var, -1-idx), (var.name, -1-idx), (-1-idx, -1-idx))
for idx, var in enumerate(self.metas)))
self._indices = None

self.anonymous = False

self._hash = None # cache for __hash__()

def _ensure_indices(self):
    """Build the ``self._indices`` lookup table on first use.

    Does nothing when the table already exists. Otherwise every entry
    of ``self._variables`` is registered under three keys -- the
    variable object, its ``name`` and its position -- all mapping to
    that position. Entries of ``self.metas`` are registered the same
    way but with negative positions: ``-1`` for the first meta, ``-2``
    for the second, and so on.
    """
    if self._indices is not None:
        return

    lookup = {}
    for position, variable in enumerate(self._variables):
        lookup[variable] = position
        lookup[variable.name] = position
        lookup[position] = position
    self._indices = lookup

    # Metas are appended to the same table under negative positions.
    for offset, meta in enumerate(self.metas):
        position = -1 - offset
        lookup[meta] = position
        lookup[meta.name] = position
        lookup[position] = position

def __setstate__(self, state):
    """Restore the instance from a pickled ``state`` dictionary.

    The stored attributes are copied into the instance dictionary,
    ``_variables`` is rebuilt as ``attributes + class_vars``, and the
    ``_indices`` and ``_hash`` caches are reset to ``None``.
    """
    vars(self).update(state)
    self._variables = self.attributes + self.class_vars
    # Start with empty caches; ``_ensure_indices`` refills ``_indices``
    # on first lookup.
    self._indices = self._hash = None

def __getstate__(self):
    """Return a picklable copy of the instance dictionary.

    ``_variables``, ``_indices`` and ``_hash`` are removed from the
    copy; ``__setstate__`` rebuilds or resets them after unpickling.
    A ``KeyError`` is raised if any of the three is missing.
    """
    # Dictionaries are deliberately not pickled: restoring a dictionary
    # whose keys are objects that redefine __hash__ is sometimes
    # problematic, because such an object may be used as a key before
    # its own __dict__ has been filled in.
    state = self.__dict__.copy()
    for derived in ("_variables", "_indices", "_hash"):
        del state[derived]
    return state

# noinspection PyPep8Naming
@classmethod
def from_numpy(cls, X, Y=None, metas=None):
Expand Down Expand Up @@ -289,7 +310,7 @@ def __getitem__(self, idx):
"""
if isinstance(idx, slice):
return self._variables[idx]

self._ensure_indices()
index = self._indices.get(idx)
if index is None:
var = self._get_equivalent(idx)
Expand All @@ -306,6 +327,7 @@ def __contains__(self, item):
Return `True` if the item (`str`, `int`, :class:`Variable`) is
in the domain.
"""
self._ensure_indices()
return item in self._indices or self._get_equivalent(item) is not None

def __iter__(self):
Expand Down Expand Up @@ -334,7 +356,7 @@ def index(self, var):
Return the index of the given variable or meta attribute, represented
with an instance of :class:`Variable`, `int` or `str`.
"""

self._ensure_indices()
idx = self._indices.get(var)
if idx is not None:
return idx
Expand Down
1 change: 1 addition & 0 deletions Orange/tests/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def test_get_item_similar_vars(self):
metas=[var1, var2]
)
# pylint: disable=protected-access
domain._ensure_indices()
self.assertDictEqual(
{-1: -1, -2: -2, var1: -1, var2: -2, var1.name: -1, var2.name: -2},
domain._indices
Expand Down

0 comments on commit 03168a2

Please sign in to comment.