diff --git a/itemloaders/__init__.py b/itemloaders/__init__.py index aaa5c2b..74a9970 100644 --- a/itemloaders/__init__.py +++ b/itemloaders/__init__.py @@ -4,24 +4,43 @@ See documentation in docs/topics/loaders.rst """ +from __future__ import annotations + from contextlib import suppress +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Pattern, + Union, +) from itemadapter import ItemAdapter +from parsel import Selector from parsel.utils import extract_regex, flatten from itemloaders.common import wrap_loader_context from itemloaders.processors import Identity from itemloaders.utils import arg_to_iter +if TYPE_CHECKING: + # typing.Self requires Python 3.11 + from typing_extensions import Self + -def unbound_method(method): +def unbound_method(method: Callable[..., Any]) -> Callable[..., Any]: """ Allow to use single-argument functions as input or output processors (no need to define an unused first 'self' argument) """ with suppress(AttributeError): if "." not in method.__qualname__: - return method.__func__ + return method.__func__ # type: ignore[attr-defined, no-any-return] return method @@ -96,40 +115,46 @@ class Product: .. _parsel: https://parsel.readthedocs.io/en/latest/ """ - default_item_class = dict - default_input_processor = Identity() - default_output_processor = Identity() - - def __init__(self, item=None, selector=None, parent=None, **context): - self.selector = selector + default_item_class: type = dict + default_input_processor: Callable[..., Any] = Identity() + default_output_processor: Callable[..., Any] = Identity() + + def __init__( + self, + item: Any = None, + selector: Optional[Selector] = None, + parent: Optional[ItemLoader] = None, + **context: Any, + ): + self.selector: Optional[Selector] = selector context.update(selector=selector) if item is None: item = self.default_item_class() self._local_item = item context["item"] = item - self.context = context - self.parent = parent - self._local_values = {} + self.context: MutableMapping[str, Any] = context + self.parent: Optional[ItemLoader] = parent + self._local_values: Dict[str, List[Any]] = {} # values from initial item for field_name, value in ItemAdapter(item).items(): self._values.setdefault(field_name, []) self._values[field_name] += arg_to_iter(value) @property - def _values(self): + def _values(self) -> Dict[str, List[Any]]: if self.parent is not None: return self.parent._values else: return self._local_values @property - def item(self): + def item(self) -> Any: if self.parent is not None: return self.parent.item else: return self._local_item - def nested_xpath(self, xpath, **context): + def nested_xpath(self, xpath: str, **context: Any) -> Self: """ Create a nested loader with an xpath selector. The supplied selector is applied relative to selector associated @@ -137,12 +162,14 @@ def nested_xpath(self, xpath, **context): with the parent :class:`ItemLoader` so calls to :meth:`add_xpath`, :meth:`add_value`, :meth:`replace_value`, etc. will behave as expected. """ + self._check_selector_method() + assert self.selector selector = self.selector.xpath(xpath) context.update(selector=selector) subloader = self.__class__(item=self.item, parent=self, **context) return subloader - def nested_css(self, css, **context): + def nested_css(self, css: str, **context: Any) -> Self: """ Create a nested loader with a css selector. The supplied selector is applied relative to selector associated @@ -150,12 +177,21 @@ def nested_css(self, css, **context): with the parent :class:`ItemLoader` so calls to :meth:`add_xpath`, :meth:`add_value`, :meth:`replace_value`, etc. will behave as expected. """ + self._check_selector_method() + assert self.selector selector = self.selector.css(css) context.update(selector=selector) subloader = self.__class__(item=self.item, parent=self, **context) return subloader - def add_value(self, field_name, value, *processors, re=None, **kw): + def add_value( + self, + field_name: Optional[str], + value: Any, + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Process and then add the given ``value`` for the given field. @@ -186,7 +222,14 @@ def add_value(self, field_name, value, *processors, re=None, **kw): else: self._add_value(field_name, value) - def replace_value(self, field_name, value, *processors, re=None, **kw): + def replace_value( + self, + field_name: Optional[str], + value: Any, + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`add_value` but replaces the collected data with the new value instead of adding it. @@ -200,18 +243,24 @@ def replace_value(self, field_name, value, *processors, re=None, **kw): else: self._replace_value(field_name, value) - def _add_value(self, field_name, value): + def _add_value(self, field_name: str, value: Any) -> None: value = arg_to_iter(value) processed_value = self._process_input_value(field_name, value) if processed_value: self._values.setdefault(field_name, []) self._values[field_name] += arg_to_iter(processed_value) - def _replace_value(self, field_name, value): + def _replace_value(self, field_name: str, value: Any) -> None: self._values.pop(field_name, None) self._add_value(field_name, value) - def get_value(self, value, *processors, re=None, **kw): + def get_value( + self, + value: Any, + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> Any: """ Process the given ``value`` by the given ``processors`` and keyword arguments. @@ -221,7 +270,7 @@ def get_value(self, value, *processors, re=None, **kw): :param re: a regular expression to use for extracting data from the given value using :func:`~parsel.utils.extract_regex` method, applied before processors - :type re: str or typing.Pattern + :type re: str or typing.Pattern[str] Examples: @@ -249,7 +298,7 @@ def get_value(self, value, *processors, re=None, **kw): ) from e return value - def load_item(self): + def load_item(self) -> Any: """ Populate the item with the data collected so far, and return it. The data collected is first passed through the :ref:`output processors @@ -263,7 +312,7 @@ def load_item(self): return adapter.item - def get_output_value(self, field_name): + def get_output_value(self, field_name: str) -> Any: """ Return the collected values parsed using the output processor, for the given field. This method doesn't populate or modify the item at all. @@ -279,11 +328,11 @@ def get_output_value(self, field_name): % (field_name, value, type(e).__name__, str(e)) ) from e - def get_collected_values(self, field_name): + def get_collected_values(self, field_name: str) -> List[Any]: """Return the collected values for the given field.""" return self._values.get(field_name, []) - def get_input_processor(self, field_name): + def get_input_processor(self, field_name: str) -> Callable[..., Any]: proc = getattr(self, "%s_in" % field_name, None) if not proc: proc = self._get_item_field_attr( @@ -291,7 +340,7 @@ def get_input_processor(self, field_name): ) return unbound_method(proc) - def get_output_processor(self, field_name): + def get_output_processor(self, field_name: str) -> Callable[..., Any]: proc = getattr(self, "%s_out" % field_name, None) if not proc: proc = self._get_item_field_attr( @@ -299,11 +348,13 @@ def get_output_processor(self, field_name): ) return unbound_method(proc) - def _get_item_field_attr(self, field_name, key, default=None): + def _get_item_field_attr( + self, field_name: str, key: Any, default: Any = None + ) -> Any: field_meta = ItemAdapter(self.item).get_field_meta(field_name) return field_meta.get(key, default) - def _process_input_value(self, field_name, value): + def _process_input_value(self, field_name: str, value: Any) -> Any: proc = self.get_input_processor(field_name) _proc = proc proc = wrap_loader_context(proc, self.context) @@ -322,14 +373,21 @@ def _process_input_value(self, field_name, value): ) ) from e - def _check_selector_method(self): + def _check_selector_method(self) -> None: if self.selector is None: raise RuntimeError( "To use XPath or CSS selectors, %s " "must be instantiated with a selector" % self.__class__.__name__ ) - def add_xpath(self, field_name, xpath, *processors, re=None, **kw): + def add_xpath( + self, + field_name: Optional[str], + xpath: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`ItemLoader.add_value` but receives an XPath instead of a value, which is used to extract a list of strings from the @@ -351,14 +409,27 @@ def add_xpath(self, field_name, xpath, *processors, re=None, **kw): values = self._get_xpathvalues(xpath, **kw) self.add_value(field_name, values, *processors, re=re, **kw) - def replace_xpath(self, field_name, xpath, *processors, re=None, **kw): + def replace_xpath( + self, + field_name: Optional[str], + xpath: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`add_xpath` but replaces collected data instead of adding it. """ values = self._get_xpathvalues(xpath, **kw) self.replace_value(field_name, values, *processors, re=re, **kw) - def get_xpath(self, xpath, *processors, re=None, **kw): + def get_xpath( + self, + xpath: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> Any: """ Similar to :meth:`ItemLoader.get_value` but receives an XPath instead of a value, which is used to extract a list of unicode strings from the @@ -369,7 +440,7 @@ def get_xpath(self, xpath, *processors, re=None, **kw): :param re: a regular expression to use for extracting data from the selected XPath region - :type re: str or typing.Pattern + :type re: str or typing.Pattern[str] Examples:: @@ -382,12 +453,22 @@ def get_xpath(self, xpath, *processors, re=None, **kw): values = self._get_xpathvalues(xpath, **kw) return self.get_value(values, *processors, re=re, **kw) - def _get_xpathvalues(self, xpaths, **kw): + def _get_xpathvalues( + self, xpaths: Union[str, Iterable[str]], **kw: Any + ) -> List[Any]: self._check_selector_method() + assert self.selector xpaths = arg_to_iter(xpaths) return flatten(self.selector.xpath(xpath, **kw).getall() for xpath in xpaths) - def add_css(self, field_name, css, *processors, re=None, **kw): + def add_css( + self, + field_name: Optional[str], + css: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`ItemLoader.add_value` but receives a CSS selector instead of a value, which is used to extract a list of unicode strings @@ -408,14 +489,27 @@ def add_css(self, field_name, css, *processors, re=None, **kw): values = self._get_cssvalues(css) self.add_value(field_name, values, *processors, re=re, **kw) - def replace_css(self, field_name, css, *processors, re=None, **kw): + def replace_css( + self, + field_name: Optional[str], + css: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`add_css` but replaces collected data instead of adding it. """ values = self._get_cssvalues(css) self.replace_value(field_name, values, *processors, re=re, **kw) - def get_css(self, css, *processors, re=None, **kw): + def get_css( + self, + css: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> Any: """ Similar to :meth:`ItemLoader.get_value` but receives a CSS selector instead of a value, which is used to extract a list of unicode strings @@ -426,7 +520,7 @@ def get_css(self, css, *processors, re=None, **kw): :param re: a regular expression to use for extracting data from the selected CSS region - :type re: str or typing.Pattern + :type re: str or typing.Pattern[str] Examples:: @@ -438,12 +532,20 @@ def get_css(self, css, *processors, re=None, **kw): values = self._get_cssvalues(css) return self.get_value(values, *processors, re=re, **kw) - def _get_cssvalues(self, csss): + def _get_cssvalues(self, csss: Union[str, Iterable[str]]) -> List[Any]: self._check_selector_method() + assert self.selector csss = arg_to_iter(csss) return flatten(self.selector.css(css).getall() for css in csss) - def add_jmes(self, field_name, jmes, *processors, re=None, **kw): + def add_jmes( + self, + field_name: Optional[str], + jmes: str, + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`ItemLoader.add_value` but receives a JMESPath selector instead of a value, which is used to extract a list of unicode strings @@ -464,14 +566,27 @@ def add_jmes(self, field_name, jmes, *processors, re=None, **kw): values = self._get_jmesvalues(jmes) self.add_value(field_name, values, *processors, re=re, **kw) - def replace_jmes(self, field_name, jmes, *processors, re=None, **kw): + def replace_jmes( + self, + field_name: Optional[str], + jmes: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> None: """ Similar to :meth:`add_jmes` but replaces collected data instead of adding it. """ values = self._get_jmesvalues(jmes) self.replace_value(field_name, values, *processors, re=re, **kw) - def get_jmes(self, jmes, *processors, re=None, **kw): + def get_jmes( + self, + jmes: Union[str, Iterable[str]], + *processors: Callable[..., Any], + re: Union[str, Pattern[str], None] = None, + **kw: Any, + ) -> Any: """ Similar to :meth:`ItemLoader.get_value` but receives a JMESPath selector instead of a value, which is used to extract a list of unicode strings @@ -494,8 +609,9 @@ def get_jmes(self, jmes, *processors, re=None, **kw): values = self._get_jmesvalues(jmes) return self.get_value(values, *processors, re=re, **kw) - def _get_jmesvalues(self, jmess): + def _get_jmesvalues(self, jmess: Union[str, Iterable[str]]) -> List[Any]: self._check_selector_method() + assert self.selector jmess = arg_to_iter(jmess) if not hasattr(self.selector, "jmespath"): raise AttributeError( diff --git a/itemloaders/common.py b/itemloaders/common.py index 6c0b7fa..9fe3d91 100644 --- a/itemloaders/common.py +++ b/itemloaders/common.py @@ -1,11 +1,14 @@ """Common functions used in Item Loaders code""" from functools import partial +from typing import Any, Callable, MutableMapping from itemloaders.utils import get_func_args -def wrap_loader_context(function, context): +def wrap_loader_context( + function: Callable[..., Any], context: MutableMapping[str, Any] +) -> Callable[..., Any]: """Wrap functions that receive loader_context to contain the context "pre-loaded" and expose a interface that receives only one argument """ diff --git a/itemloaders/processors.py b/itemloaders/processors.py index 7bef67c..10387b5 100644 --- a/itemloaders/processors.py +++ b/itemloaders/processors.py @@ -5,6 +5,7 @@ """ from collections import ChainMap +from typing import Any, Callable, Iterable, List, MutableMapping, Optional from itemloaders.common import wrap_loader_context from itemloaders.utils import arg_to_iter @@ -54,19 +55,22 @@ class MapCompose: .. _`parsel selectors`: https://parsel.readthedocs.io/en/latest/parsel.html#parsel.selector.Selector.extract """ # noqa - def __init__(self, *functions, **default_loader_context): + def __init__(self, *functions: Callable[..., Any], **default_loader_context: Any): self.functions = functions self.default_loader_context = default_loader_context - def __call__(self, value, loader_context=None): + def __call__( + self, value: Any, loader_context: Optional[MutableMapping[str, Any]] = None + ) -> Iterable[Any]: values = arg_to_iter(value) + context: MutableMapping[str, Any] if loader_context: context = ChainMap(loader_context, self.default_loader_context) else: context = self.default_loader_context wrapped_funcs = [wrap_loader_context(f, context) for f in self.functions] for func in wrapped_funcs: - next_values = [] + next_values: List[Any] = [] for v in values: try: next_values += arg_to_iter(func(v)) @@ -109,12 +113,15 @@ class Compose: ` attribute. """ - def __init__(self, *functions, **default_loader_context): + def __init__(self, *functions: Callable[..., Any], **default_loader_context: Any): self.functions = functions self.stop_on_none = default_loader_context.get("stop_on_none", True) self.default_loader_context = default_loader_context - def __call__(self, value, loader_context=None): + def __call__( + self, value: Any, loader_context: Optional[MutableMapping[str, Any]] = None + ) -> Any: + context: MutableMapping[str, Any] if loader_context: context = ChainMap(loader_context, self.default_loader_context) else: @@ -148,7 +155,7 @@ class TakeFirst: 'one' """ - def __call__(self, values): + def __call__(self, values: Any) -> Any: for value in values: if value is not None and value != "": return value @@ -168,7 +175,7 @@ class Identity: ['one', 'two', 'three'] """ - def __call__(self, values): + def __call__(self, values: Any) -> Any: return values @@ -198,13 +205,15 @@ class SelectJmes: ['bar'] """ - def __init__(self, json_path): - self.json_path = json_path - import jmespath + def __init__(self, json_path: str): + self.json_path: str = json_path + import jmespath.parser - self.compiled_path = jmespath.compile(self.json_path) + self.compiled_path: jmespath.parser.ParsedResult = jmespath.compile( + self.json_path + ) - def __call__(self, value): + def __call__(self, value: Any) -> Any: """Query value for the jmespath query and return answer :param value: a data structure (dict, list) to extract from :return: Element extracted according to jmespath query @@ -231,8 +240,8 @@ class Join: 'one
two
three' """ - def __init__(self, separator=" "): + def __init__(self, separator: str = " "): self.separator = separator - def __call__(self, values): + def __call__(self, values: Any) -> str: return self.separator.join(values) diff --git a/itemloaders/utils.py b/itemloaders/utils.py index 3103858..91a6556 100644 --- a/itemloaders/utils.py +++ b/itemloaders/utils.py @@ -5,10 +5,10 @@ import inspect from functools import partial -from typing import Generator +from typing import Any, Callable, Generator, Iterable, List -def arg_to_iter(arg): +def arg_to_iter(arg: Any) -> Iterable[Any]: """Return an iterable based on *arg*. If *arg* is a list, a tuple or a generator, it will be returned as is. @@ -25,12 +25,12 @@ def arg_to_iter(arg): return [arg] -def get_func_args(func, stripself=False): +def get_func_args(func: Callable[..., Any], stripself: bool = False) -> List[str]: """Return the argument name list of a callable object""" if not callable(func): raise TypeError(f"func must be callable, got {type(func).__name__!r}") - args = [] + args: List[str] = [] try: sig = inspect.signature(func) except ValueError: diff --git a/setup.cfg b/setup.cfg index 6e8d795..4f96012 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,15 @@ [flake8] -ignore = E266, E501, W503 +ignore = E266, E501, E704, W503 max-line-length = 100 select = B,C,E,F,W,T4,B9 exclude = .git,__pycache__,.venv [isort] profile = black + +[mypy] + +[mypy-tests.*] +# Allow test functions to be untyped +allow_untyped_defs = true +check_untyped_defs = true diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_base_loader.py b/tests/test_base_loader.py index c0bf007..e7919f4 100644 --- a/tests/test_base_loader.py +++ b/tests/test_base_loader.py @@ -300,17 +300,17 @@ def test_output_processor_using_classes(self): il.add_value("name", ["mar", "ta"]) self.assertEqual(il.get_output_value("name"), ["Mar", "Ta"]) - class TakeFirstItemLoader(CustomItemLoader): + class TakeFirstItemLoader1(CustomItemLoader): name_out = Join() - il = TakeFirstItemLoader() + il = TakeFirstItemLoader1() il.add_value("name", ["mar", "ta"]) self.assertEqual(il.get_output_value("name"), "Mar Ta") - class TakeFirstItemLoader(CustomItemLoader): + class TakeFirstItemLoader2(CustomItemLoader): name_out = Join("
") - il = TakeFirstItemLoader() + il = TakeFirstItemLoader2() il.add_value("name", ["mar", "ta"]) self.assertEqual(il.get_output_value("name"), "Mar
Ta") diff --git a/tests/test_loader_initialization.py b/tests/test_loader_initialization.py index 0f63253..67b9449 100644 --- a/tests/test_loader_initialization.py +++ b/tests/test_loader_initialization.py @@ -1,12 +1,21 @@ import unittest +from typing import Any, Protocol from itemloaders import ItemLoader +class InitializationTestProtocol(Protocol): + item_class: Any + + def assertEqual(self, first: Any, second: Any, msg: Any = ...) -> None: ... + + def assertIsInstance(self, obj: object, cls: type, msg: Any = None) -> None: ... + + class InitializationTestMixin: - item_class = None + item_class: Any = None - def test_keep_single_value(self): + def test_keep_single_value(self: InitializationTestProtocol) -> None: """Loaded item should contain values from the initial item""" input_item = self.item_class(name="foo") il = ItemLoader(item=input_item) @@ -14,7 +23,7 @@ def test_keep_single_value(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(dict(loaded_item), {"name": ["foo"]}) - def test_keep_list(self): + def test_keep_list(self: InitializationTestProtocol) -> None: """Loaded item should contain values from the initial item""" input_item = self.item_class(name=["foo", "bar"]) il = ItemLoader(item=input_item) @@ -22,7 +31,9 @@ def test_keep_list(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(dict(loaded_item), {"name": ["foo", "bar"]}) - def test_add_value_singlevalue_singlevalue(self): + def test_add_value_singlevalue_singlevalue( + self: InitializationTestProtocol, + ) -> None: """Values added after initialization should be appended""" input_item = self.item_class(name="foo") il = ItemLoader(item=input_item) @@ -31,7 +42,7 @@ def test_add_value_singlevalue_singlevalue(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(dict(loaded_item), {"name": ["foo", "bar"]}) - def test_add_value_singlevalue_list(self): + def test_add_value_singlevalue_list(self: InitializationTestProtocol) -> None: """Values added after initialization should be appended""" input_item = self.item_class(name="foo") il = ItemLoader(item=input_item) @@ -40,7 +51,7 @@ def test_add_value_singlevalue_list(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(dict(loaded_item), {"name": ["foo", "item", "loader"]}) - def test_add_value_list_singlevalue(self): + def test_add_value_list_singlevalue(self: InitializationTestProtocol) -> None: """Values added after initialization should be appended""" input_item = self.item_class(name=["foo", "bar"]) il = ItemLoader(item=input_item) @@ -49,7 +60,7 @@ def test_add_value_list_singlevalue(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(dict(loaded_item), {"name": ["foo", "bar", "qwerty"]}) - def test_add_value_list_list(self): + def test_add_value_list_list(self: InitializationTestProtocol) -> None: """Values added after initialization should be appended""" input_item = self.item_class(name=["foo", "bar"]) il = ItemLoader(item=input_item) @@ -58,7 +69,7 @@ def test_add_value_list_list(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(dict(loaded_item), {"name": ["foo", "bar", "item", "loader"]}) - def test_get_output_value_singlevalue(self): + def test_get_output_value_singlevalue(self: InitializationTestProtocol) -> None: """Getting output value must not remove value from item""" input_item = self.item_class(name="foo") il = ItemLoader(item=input_item) @@ -67,7 +78,7 @@ def test_get_output_value_singlevalue(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(loaded_item, {"name": ["foo"]}) - def test_get_output_value_list(self): + def test_get_output_value_list(self: InitializationTestProtocol) -> None: """Getting output value must not remove value from item""" input_item = self.item_class(name=["foo", "bar"]) il = ItemLoader(item=input_item) @@ -76,13 +87,13 @@ def test_get_output_value_list(self): self.assertIsInstance(loaded_item, self.item_class) self.assertEqual(loaded_item, {"name": ["foo", "bar"]}) - def test_values_single(self): + def test_values_single(self: InitializationTestProtocol) -> None: """Values from initial item must be added to loader._values""" input_item = self.item_class(name="foo") il = ItemLoader(item=input_item) self.assertEqual(il._values.get("name"), ["foo"]) - def test_values_list(self): + def test_values_list(self: InitializationTestProtocol) -> None: """Values from initial item must be added to loader._values""" input_item = self.item_class(name=["foo", "bar"]) il = ItemLoader(item=input_item) diff --git a/tests/test_nested_items.py b/tests/test_nested_items.py index 444431a..fee9913 100644 --- a/tests/test_nested_items.py +++ b/tests/test_nested_items.py @@ -1,4 +1,5 @@ import unittest +from typing import Any from itemloaders import ItemLoader @@ -6,7 +7,7 @@ class NestedItemTest(unittest.TestCase): """Test that adding items as values works as expected.""" - def _test_item(self, item): + def _test_item(self, item: Any) -> None: il = ItemLoader() il.add_value("item_list", item) self.assertEqual(il.load_item(), {"item_list": [item]}) @@ -44,7 +45,8 @@ def test_scrapy_item(self): except ImportError: self.skipTest("Cannot import Field or Item from scrapy") - class TestItem(Item): + # needs py.typed in Scrapy + class TestItem(Item): # type: ignore[misc] foo = Field() self._test_item(TestItem(foo="bar")) diff --git a/tests/test_nested_loader.py b/tests/test_nested_loader.py index 58b9bec..82d24f7 100644 --- a/tests/test_nested_loader.py +++ b/tests/test_nested_loader.py @@ -28,6 +28,7 @@ def test_nested_xpath(self): nl = loader.nested_xpath("//header") nl.add_xpath("name", "div/text()") nl.add_css("name_div", "#id") + assert nl.selector nl.add_value("name_value", nl.selector.xpath('div[@id = "id"]/text()').getall()) self.assertEqual(loader.get_output_value("name"), ["marta"]) @@ -49,6 +50,7 @@ def test_nested_css(self): nl = loader.nested_css("header") nl.add_xpath("name", "div/text()") nl.add_css("name_div", "#id") + assert nl.selector nl.add_value("name_value", nl.selector.xpath('div[@id = "id"]/text()').getall()) self.assertEqual(loader.get_output_value("name"), ["marta"]) diff --git a/tests/test_output_processor.py b/tests/test_output_processor.py index f4aa387..09ef95d 100644 --- a/tests/test_output_processor.py +++ b/tests/test_output_processor.py @@ -1,4 +1,5 @@ import unittest +from typing import Any, Dict from itemloaders import ItemLoader from itemloaders.processors import Compose, Identity, TakeFirst @@ -6,7 +7,7 @@ class TestOutputProcessorDict(unittest.TestCase): def test_output_processor(self): - class TempDict(dict): + class TempDict(Dict[str, Any]): def __init__(self, *args, **kwargs): super(TempDict, self).__init__(self, *args, **kwargs) self.setdefault("temp", 0.3) @@ -28,7 +29,7 @@ class TempLoader(ItemLoader): default_input_processor = Identity() default_output_processor = Compose(TakeFirst()) - item = {} + item: Dict[str, Any] = {} item.setdefault("temp", 0.3) loader = TempLoader(item=item) item = loader.load_item() diff --git a/tests/test_utils_python.py b/tests/test_utils_python.py index 83938ed..94f5ecb 100644 --- a/tests/test_utils_python.py +++ b/tests/test_utils_python.py @@ -2,6 +2,7 @@ import operator import platform import unittest +from typing import Any from itemloaders.utils import get_func_args @@ -18,7 +19,7 @@ def f3(a, b=None, *, c=None): pass class A: - def __init__(self, a, b, c): + def __init__(self, a: Any, b: Any, c: Any): pass def method(self, a, b, c): diff --git a/tox.ini b/tox.ini index b365089..ee32ddb 100644 --- a/tox.ini +++ b/tox.ini @@ -45,3 +45,12 @@ deps = commands = python -m build --sdist twine check dist/* + +[testenv:typing] +basepython = python3 +deps = + mypy==1.10.0 + types-attrs==19.1.0 + types-jmespath==1.0.2.20240106 +commands = + mypy --strict --ignore-missing-imports --implicit-reexport {posargs:itemloaders tests}