diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index 1d402ff0c98..3095778fe22 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -1,5 +1,7 @@ from wasabi import msg +# Needed for testing +from . import download as download_module # noqa: F401 from ._util import app, setup_cli # noqa: F401 from .apply import apply # noqa: F401 from .assemble import assemble_cli # noqa: F401 diff --git a/spacy/cli/download.py b/spacy/cli/download.py index 21c777f81fb..4261fb830d9 100644 --- a/spacy/cli/download.py +++ b/spacy/cli/download.py @@ -1,5 +1,6 @@ import sys from typing import Optional, Sequence +from urllib.parse import urljoin import requests import typer @@ -63,6 +64,13 @@ def download( ) pip_args = pip_args + ("--no-deps",) if direct: + # Reject model names with '/', in order to prevent shenanigans. + if "/" in model: + msg.fail( + title="Model download rejected", + text=f"Cannot download model '{model}'. Models are expected to be file names, not URLs or fragments", + exits=True, + ) components = model.split("-") model_name = "".join(components[:-1]) version = components[-1] @@ -153,7 +161,16 @@ def get_latest_version(model: str) -> str: def download_model( filename: str, user_pip_args: Optional[Sequence[str]] = None ) -> None: - download_url = about.__download_url__ + "/" + filename + # Construct the download URL carefully. We need to make sure we don't + # allow relative paths or other shenanigans to trick us into download + # from outside our own repo. + base_url = about.__download_url__ + # urljoin requires that the path ends with /, or the last path part will be dropped + if not base_url.endswith("/"): + base_url = about.__download_url__ + "/" + download_url = urljoin(base_url, filename) + if not download_url.startswith(about.__download_url__): + raise ValueError(f"Download from {filename} rejected. Was it a relative path?") pip_args = list(user_pip_args) if user_pip_args is not None else [] cmd = [sys.executable, "-m", "pip", "install"] + pip_args + [download_url] run_command(cmd) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index ff53ed1e1b0..7b729d78f21 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -12,7 +12,7 @@ import spacy from spacy import about -from spacy.cli import info +from spacy.cli import download_module, info from spacy.cli._util import parse_config_overrides, string_to_list, walk_directory from spacy.cli.apply import apply from spacy.cli.debug_data import ( @@ -1066,3 +1066,15 @@ def test_debug_data_trainable_lemmatizer_not_annotated(): def test_project_api_imports(): from spacy.cli import project_run from spacy.cli.project.run import project_run # noqa: F401, F811 + + +def test_download_rejects_relative_urls(monkeypatch): + """Test that we can't tell spacy download to get an arbitrary model by using a + relative path in the filename""" + + monkeypatch.setattr(download_module, "run_command", lambda cmd: None) + + # Check that normal download works + download_module.download("en_core_web_sm-3.7.1", direct=True) + with pytest.raises(SystemExit): + download_module.download("../en_core_web_sm-3.7.1", direct=True)