Skip to content

Commit

Permalink
Remove code violating deepcopy semantics
Browse files Browse the repository at this point in the history
Add comments
  • Loading branch information
shadeMe committed Jul 6, 2023
1 parent 74288ba commit 88f983c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
6 changes: 3 additions & 3 deletions curated_tokenizers/_bbpe.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ cdef class ByteBPEProcessor:
return ByteBPEProcessor(vocab=self.vocab, merges=self.merges)

def __deepcopy__(self, memo):
result = ByteBPEProcessor(vocab=self.vocab, merges=self.merges)
memo[id(self)] = result
return result
# We don't need a deepcopy of the vocab and merges dicts, as their
# contents will be copied into a backing store in the constructor.
return ByteBPEProcessor(vocab=self.vocab, merges=self.merges)

@staticmethod
def load_from_files(*, vocab: Path, merges: Path) -> ByteBPEProcessor:
Expand Down
1 change: 1 addition & 0 deletions curated_tokenizers/_spp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cdef class SentencePieceProcessor:
"""
cdef SentencePieceProcessor processor = SentencePieceProcessor.__new__(SentencePieceProcessor)
if len(protobuf) == 0:
# SentencePiece returns an empty protobuf for uninitialized models.
return processor
cdef string_view protobuf_view = string_view(protobuf, len(protobuf))
_check_status(deref(processor.spp).LoadFromSerializedProto(protobuf_view))
Expand Down
12 changes: 3 additions & 9 deletions curated_tokenizers/_wordpiece.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@ cdef class WordPieceProcessor:
self._pieces.add_piece(byte_array, is_initial)

def __copy__(self):
cls = self.__class__
data = self.to_list()
result = cls(data)
return result
# This is essentially a deepcopy, but there's no better way to do it.
return WordPieceProcessor(self.to_list())

def __deepcopy__(self, memo):
cls = self.__class__
data = self.to_list()
result = cls(data)
memo[id(self)] = result
return result
return WordPieceProcessor(self.to_list())

def encode(self, token: str) -> Tuple[List[int], List[str]]:
"""
Expand Down

0 comments on commit 88f983c

Please sign in to comment.