From 814d04d7c98c50ae1b79ccea58ace7a651c2fc11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 8 Feb 2024 12:11:30 +0100 Subject: [PATCH] black --- .github/workflows/test.yml | 2 +- curated_transformers/models/falcon/layer.py | 8 +++++--- curated_transformers/tests/models/util.py | 4 +--- .../tokenizers/legacy/legacy_tokenizer.py | 6 ++---- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f5e901c3..e775e1e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: - name: Configure Python version uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.11" architecture: x64 - name: black diff --git a/curated_transformers/models/falcon/layer.py b/curated_transformers/models/falcon/layer.py index 88d30135..136b2700 100644 --- a/curated_transformers/models/falcon/layer.py +++ b/curated_transformers/models/falcon/layer.py @@ -73,9 +73,11 @@ def __init__( n_key_value_heads=attention_config.n_key_value_heads, ), rotary_embeds=rotary_embeds, - qkv_mode=QkvMode.MERGED_SPLIT_AFTER - if attention_config.n_key_value_heads == 1 - else QkvMode.MERGED_SPLIT_BEFORE, + qkv_mode=( + QkvMode.MERGED_SPLIT_AFTER + if attention_config.n_key_value_heads == 1 + else QkvMode.MERGED_SPLIT_BEFORE + ), use_bias=attention_config.use_bias, device=device, ) diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py index 3f7913bf..279d3a56 100644 --- a/curated_transformers/tests/models/util.py +++ b/curated_transformers/tests/models/util.py @@ -58,9 +58,7 @@ class JITMethod(Enum): TorchCompile = 1 TorchScriptTrace = 2 - def convert( - self, model: Module, with_torch_sdp: bool, *args - ) -> Tuple[ + def convert(self, model: Module, with_torch_sdp: bool, *args) -> Tuple[ Union[Module, torch.ScriptModule], Callable[[Union[ModelOutput, Dict[str, torch.Tensor]]], Tensor], ]: diff --git a/curated_transformers/tokenizers/legacy/legacy_tokenizer.py b/curated_transformers/tokenizers/legacy/legacy_tokenizer.py index 12de0014..9c05cd3a 100644 --- a/curated_transformers/tokenizers/legacy/legacy_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/legacy_tokenizer.py @@ -155,12 +155,10 @@ def _convert_strings( @abstractmethod def _decode( self, input: Iterable[Iterable[int]], skip_special_pieces: bool - ) -> List[str]: - ... + ) -> List[str]: ... @abstractmethod - def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds: - ... + def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds: ... class AddBosEosPreEncoder(PreEncoder):