From e29173698545bc35b17c5241212a12d4a9a5157b Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 14 Dec 2021 07:15:49 -0500 Subject: [PATCH] Fix import errors without dev dependencies (#214) * add ci test for import without dev dependencies * fix tensorflow import error --- .github/workflows/ci_test.yml | 31 +++++++++++++++++++++++++++++++ elegy/model/model_core.py | 5 ++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml index a36fb6a5..55a2f08d 100644 --- a/.github/workflows/ci_test.yml +++ b/.github/workflows/ci_test.yml @@ -54,3 +54,34 @@ jobs: - name: Test Examples run: bash scripts/test-examples.sh + + test-import: + name: Test Import without Dev Dependencies + if: ${{ !contains(github.event.pull_request.title, 'WIP') }} + runs-on: ubuntu-latest + strategy: + matrix: + # python-version: [3.9] + python-version: [3.7, 3.8, 3.9] + steps: + - name: Check out the code + uses: actions/checkout@v2 + with: + fetch-depth: 1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Poetry + uses: snok/install-poetry@v1.1.1 + with: + version: 1.1.4 + + - name: Install Dependencies + run: | + poetry config virtualenvs.create false + poetry install --no-dev + + - name: Import Elegy + run: python -c "import elegy" diff --git a/elegy/model/model_core.py b/elegy/model/model_core.py index 8de370c2..facf170c 100644 --- a/elegy/model/model_core.py +++ b/elegy/model/model_core.py @@ -10,7 +10,6 @@ import jax.numpy as jnp import numpy as np import treex as tx -from jax.experimental import jax2tf from elegy import types, utils @@ -18,8 +17,10 @@ try: import tensorflow as tf + from jax.experimental import jax2tf except ImportError: tf = None + jax2tf = None A = tp.TypeVar("A") M = tp.TypeVar("M", bound="ModelCore") @@ -727,6 +728,8 @@ def saved_model( if model_utils.convert_and_save_model is None: raise ImportError(f"Could not import tensorflow.") + assert jax2tf is not None + if isinstance(batch_size, int): batch_size = [batch_size]