diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 79978bd..28ef111 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: hooks: - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.5.6" + rev: "v0.6.3" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/docs/benchmarks.ipynb b/docs/benchmarks.ipynb index 6908583..a1b0f8a 100644 --- a/docs/benchmarks.ipynb +++ b/docs/benchmarks.ipynb @@ -118,13 +118,12 @@ "source": [ "from functools import partial\n", "\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", + "import celerite2\n", + "import george\n", "import jax\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", - "import george\n", - "import celerite2\n", "import tinygp\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", diff --git a/docs/tutorials/derivative.ipynb b/docs/tutorials/derivative.ipynb index ca9ec08..b408543 100644 --- a/docs/tutorials/derivative.ipynb +++ b/docs/tutorials/derivative.ipynb @@ -60,9 +60,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "\n", + "import numpy as np\n", "\n", "X = np.linspace(0.0, 5 * np.pi, 50)\n", "y = np.concatenate(\n", @@ -97,10 +96,11 @@ "metadata": {}, "outputs": [], "source": [ - "import tinygp\n", "import jax\n", "import jax.numpy as jnp\n", "\n", + "import tinygp\n", + "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "\n", @@ -169,10 +169,10 @@ " )[0]\n", "\n", " plt.figure()\n", - " plt.plot(dt, k00, label=\"$\\mathrm{cov}(f,\\,f)$\", lw=1)\n", - " plt.plot(dt, k01, label=\"$\\mathrm{cov}(f,\\,\\dot{f})$\", lw=1)\n", - " plt.plot(dt, k10, label=\"$\\mathrm{cov}(\\dot{f},\\,f)$\", lw=1)\n", - " plt.plot(dt, k11, label=\"$\\mathrm{cov}(\\dot{f},\\,\\dot{f})$\", lw=1)\n", + " plt.plot(dt, k00, label=r\"$\\mathrm{cov}(f,\\,f)$\", lw=1)\n", + " plt.plot(dt, k01, label=r\"$\\mathrm{cov}(f,\\,\\dot{f})$\", lw=1)\n", + " plt.plot(dt, k10, label=r\"$\\mathrm{cov}(\\dot{f},\\,f)$\", lw=1)\n", + " plt.plot(dt, k11, label=r\"$\\mathrm{cov}(\\dot{f},\\,\\dot{f})$\", lw=1)\n", " plt.legend()\n", " plt.xlabel(r\"$\\Delta t$\")\n", " plt.xlim(dt.min(), dt.max())\n", diff --git a/docs/tutorials/geometry.ipynb b/docs/tutorials/geometry.ipynb index b155db2..0f58103 100644 --- a/docs/tutorials/geometry.ipynb +++ b/docs/tutorials/geometry.ipynb @@ -47,11 +47,12 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", "import jax\n", "import jax.numpy as jnp\n", - "from tinygp import kernels, GaussianProcess\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from tinygp import GaussianProcess, kernels\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", diff --git a/docs/tutorials/intro.ipynb b/docs/tutorials/intro.ipynb index 0ad9d41..a0383c2 100644 --- a/docs/tutorials/intro.ipynb +++ b/docs/tutorials/intro.ipynb @@ -82,8 +82,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "\n", "def plot_kernel(kernel, **kwargs):\n", diff --git a/docs/tutorials/kernels.ipynb b/docs/tutorials/kernels.ipynb index 9f3f263..e706c1c 100644 --- a/docs/tutorials/kernels.ipynb +++ b/docs/tutorials/kernels.ipynb @@ -48,10 +48,11 @@ "metadata": {}, "outputs": [], "source": [ - "import tinygp\n", "import jax\n", "import jax.numpy as jnp\n", "\n", + "import tinygp\n", + "\n", "\n", "class SpectralMixture(tinygp.kernels.Kernel):\n", " weight: jax.Array\n", @@ -85,8 +86,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "\n", "def build_gp(theta):\n", diff --git a/docs/tutorials/likelihoods.ipynb b/docs/tutorials/likelihoods.ipynb index dd2ff08..a46c2f6 100644 --- a/docs/tutorials/likelihoods.ipynb +++ b/docs/tutorials/likelihoods.ipynb @@ -54,8 +54,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "random = np.random.default_rng(203618)\n", "x = np.linspace(-3, 3, 20)\n", @@ -90,7 +90,8 @@ "import jax.numpy as jnp\n", "import numpyro\n", "import numpyro.distributions as dist\n", - "from tinygp import kernels, GaussianProcess\n", + "\n", + "from tinygp import GaussianProcess, kernels\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", diff --git a/docs/tutorials/means.ipynb b/docs/tutorials/means.ipynb index 8176a00..cd3808e 100644 --- a/docs/tutorials/means.ipynb +++ b/docs/tutorials/means.ipynb @@ -52,10 +52,11 @@ "outputs": [], "source": [ "from functools import partial\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", + "\n", "import jax\n", "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", @@ -172,7 +173,7 @@ "metadata": {}, "outputs": [], "source": [ - "from tinygp import kernels, GaussianProcess\n", + "from tinygp import GaussianProcess, kernels\n", "\n", "\n", "def build_gp(params):\n", diff --git a/docs/tutorials/mixture.ipynb b/docs/tutorials/mixture.ipynb index 0090ba8..73ccdf9 100644 --- a/docs/tutorials/mixture.ipynb +++ b/docs/tutorials/mixture.ipynb @@ -66,9 +66,8 @@ "source": [ "import jax\n", "import jax.numpy as jnp\n", - "\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "from tinygp import GaussianProcess, kernels, transforms\n", "\n", diff --git a/docs/tutorials/modeling.ipynb b/docs/tutorials/modeling.ipynb index ba983ac..71a7f7c 100644 --- a/docs/tutorials/modeling.ipynb +++ b/docs/tutorials/modeling.ipynb @@ -67,8 +67,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "random = np.random.default_rng(42)\n", "\n", @@ -120,15 +120,13 @@ "metadata": {}, "outputs": [], "source": [ - "from tinygp import kernels, GaussianProcess\n", - "\n", + "import flax.linen as nn\n", "import jax\n", "import jax.numpy as jnp\n", - "\n", - "import flax.linen as nn\n", + "import optax\n", "from flax.linen.initializers import zeros\n", "\n", - "import optax\n", + "from tinygp import GaussianProcess, kernels\n", "\n", "\n", "class GPModule(nn.Module):\n", diff --git a/docs/tutorials/multivariate.ipynb b/docs/tutorials/multivariate.ipynb index 6009611..9c15c75 100644 --- a/docs/tutorials/multivariate.ipynb +++ b/docs/tutorials/multivariate.ipynb @@ -74,8 +74,9 @@ "outputs": [], "source": [ "import jax\n", - "import numpy as np\n", "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", "from tinygp import kernels\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", @@ -117,8 +118,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "random = np.random.default_rng(48392)\n", "X = random.uniform(-5, 5, (100, 2))\n", @@ -163,6 +164,7 @@ "outputs": [], "source": [ "import jaxopt\n", + "\n", "from tinygp import GaussianProcess, kernels, transforms\n", "\n", "\n", diff --git a/docs/tutorials/quasisep-custom.ipynb b/docs/tutorials/quasisep-custom.ipynb index d1e7629..465cd3e 100644 --- a/docs/tutorials/quasisep-custom.ipynb +++ b/docs/tutorials/quasisep-custom.ipynb @@ -131,8 +131,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "random = np.random.default_rng(394)\n", "t = np.sort(random.uniform(0, 10, 700))\n", diff --git a/docs/tutorials/quasisep.ipynb b/docs/tutorials/quasisep.ipynb index 8028317..73fbaa4 100644 --- a/docs/tutorials/quasisep.ipynb +++ b/docs/tutorials/quasisep.ipynb @@ -57,8 +57,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "random = np.random.default_rng(42)\n", "\n", @@ -108,7 +108,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from tinygp import kernels, GaussianProcess\n", + "from tinygp import GaussianProcess, kernels\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", diff --git a/docs/tutorials/quickstart.ipynb b/docs/tutorials/quickstart.ipynb index a35671d..244f4f5 100644 --- a/docs/tutorials/quickstart.ipynb +++ b/docs/tutorials/quickstart.ipynb @@ -48,8 +48,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "from statsmodels.datasets import co2\n", "\n", "data = co2.load_pandas().data\n", @@ -102,8 +102,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from tinygp import kernels, GaussianProcess\n", - "\n", + "from tinygp import GaussianProcess, kernels\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index 220329a..6942d97 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -61,8 +61,8 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "random = np.random.default_rng(567)\n", "\n", @@ -99,12 +99,13 @@ "metadata": {}, "outputs": [], "source": [ + "import flax.linen as nn\n", "import jax\n", - "import optax\n", "import jax.numpy as jnp\n", - "import flax.linen as nn\n", + "import optax\n", "from flax.linen.initializers import zeros\n", - "from tinygp import kernels, transforms, GaussianProcess" + "\n", + "from tinygp import GaussianProcess, kernels, transforms" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 6988f04..8a44938 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ line-length = 88 [tool.ruff] target-version = "py39" line-length = 88 + +[tool.ruff.lint] select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"] ignore = [ "E741", # Allow ambiguous variable names @@ -60,7 +62,15 @@ ignore = [ ] exclude = [] -[tool.ruff.isort] +[tool.ruff.lint.per-file-ignores] +"docs/tutorials/*.ipynb" = [ + "B007", # Loop variable is not used + "E501", # Line too long + "E731", # Do not assign a lambda expression + "F401", # Unused imports +] + +[tool.ruff.lint.isort] known-first-party = ["tinygp"] combine-as-imports = true