Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pytest_collection_modify hook #107

Merged
merged 4 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion scpdt/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class DTConfig:
NameErrors. Set to True if you want to see these, or if your test
is actually expected to raise NameErrors.
Default is False.
pytest_extra_skips : list
A list of names/modules to skip when run under pytest plugin. Ignored
otherwise.

"""
def __init__(self, *, # DTChecker configuration
Expand All @@ -88,6 +91,7 @@ def __init__(self, *, # DTChecker configuration
# Obscure switches
parse_namedtuples=True, # Checker
nameerror_after_exception=False, # Runner
pytest_extra_skips=None, # plugin/collection
):
### DTChecker configuration ###
# The namespace to run examples in
Expand Down Expand Up @@ -141,7 +145,8 @@ def __init__(self, *, # DTChecker configuration
'set_title', 'imshow', 'plt.show', '.axis(', '.plot(',
'.bar(', '.title', '.ylabel', '.xlabel', 'set_ylim', 'set_xlim',
'# reformatted', '.set_xlabel(', '.set_ylabel(', '.set_zlabel(',
'.set(xlim=', '.set(ylim=', '.set(xlabel=', '.set(ylabel=', '.xlim('}
'.set(xlim=', '.set(ylim=', '.set(xlabel=', '.set(ylabel=', '.xlim('
'ax.set('}
self.stopwords = stopwords

if pseudocode is None:
Expand Down Expand Up @@ -170,6 +175,11 @@ def __init__(self, *, # DTChecker configuration
self.parse_namedtuples = parse_namedtuples
self.nameerror_after_exception = nameerror_after_exception

#### pytest plugin additional switches
if pytest_extra_skips is None:
pytest_extra_skips = []
self.pytest_extra_skips = pytest_extra_skips


def try_convert_namedtuple(got):
# suppose that "got" is smth like MoodResult(statistic=10, pvalue=0.1).
Expand Down
99 changes: 37 additions & 62 deletions scpdt/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,81 +57,55 @@ def pytest_ignore_collect(collection_path, config):
path_str = str(collection_path)
if "tests" in path_str or "test_" in path_str:
return True


def pytest_collection_modifyitems(config, items):
"""
This hook is executed after test collection and allows you to modify the list of collected items.

The function removes duplicate Doctest items.
The function removes
- duplicate Doctest items (e.g., scipy.stats.norm and scipy.stats.distributions.norm)
- Doctest items from underscored or otherwise private modules (e.g., scipy.special._precompute)

Doctest items are collected from all public modules, including the __all__ attribute in __init__.py.
This may lead to Doctest items being collected and tested more than once.
We therefore need to remove the duplicate items by creating a new list with only unique items.
Note that this functions cooperates with and cleans up after `DTModule.collect`, which does the
bulk of the collection work.
"""
# XXX: The logic in this function can probably be folded into DTModule.collect.
# I (E.B.) quickly tried it and it does not seem to just work. Apparently something
# pytest-y runs in between DTModule.collect and this hook (should that something
# be the proper home for all collection?)

if config.getoption("--doctest-modules"):
seen_test_names = set()
unique_items = []

for item in items:
# Extract the item name, e.g., 'gauss_spline'
# Example item: <DoctestItem scipy.signal._bsplines.gauss_spline>
item_name = str(item).split('.')[-1].strip('>')

# In case the preceding string represents a function or a class,
# We need to keep the object name as both items represent different functions
# eg: <DoctestItem scipy.signal._ltisys.bode>
# <DoctestItem scipy.signal._ltisys.lti.bode>
obj_name = str(item).split('.')[-2]

# Extract the module path from the item's dtest attribute
# Example dtest: <DocTest scipy.signal.__init__.gauss_spline from /scipy/build-install/lib/python3.10/site-packages/scipy/signal/_bsplines.py:226 (5 examples)>
dtest = item.dtest
path = str(dtest).split(' ')[3].split(':')[0]

# Import the module to check if the object name is an attribute of the module
try:
module = import_path(
path,
root=config.rootpath,
mode=config.getoption("importmode"),
)
except ImportError:
module = None

# Combine the module path, object name (if it exists) and item name to create a unique identifier
if module is not None and obj_name != '__init__' and hasattr(module, obj_name) and callable(getattr(module, obj_name)) and obj_name != item_name:
unique_test_name = f"{path}/{obj_name}.{item_name}"
else:
unique_test_name = f"{path}/{item_name}"

# Check if the test name is unique and add it to the unique_items list if it is
if unique_test_name not in seen_test_names:
seen_test_names.add(unique_test_name)
assert isinstance(item.parent, DTModule)

# objects are collected twice: from their public module + from the impl module
# e.g. for `levy_stable` we have
# (Pdb) p item.name, item.parent.name
# ('scipy.stats.levy_stable', 'build-install/lib/python3.10/site-packages/scipy/stats/__init__.py')
# ('scipy.stats.distributions.levy_stable', 'distributions.py')
# so we filter out the second occurence
#
# There are two options:
# - either the impl module has a leading underscore, or
# - it needs to be explicitly listed in 'extra_skips' config key
#
# Note that the last part cannot be automated: scipy.cluster.vq is public, but
# scipy.stats.distributions is not
extra_skips = config.dt_config.pytest_extra_skips

parent_full_name = item.parent.module.__name__
is_public = "._" not in parent_full_name
is_duplicate = parent_full_name in extra_skips or item.name in extra_skips

if is_public and not is_duplicate:
unique_items.append(item)

# Replace the original list of test items with the unique ones
items[:] = unique_items

# Generate a log of the unique items to be doctested
# Extract the DoctestItem name
for item in items:
dtest = item.dtest
path = str(dtest).split(' ')[3].split(':')[0]

# Import the module being doctested
try:
module = import_path(
path,
root=config.rootpath,
mode=config.getoption("importmode"),
)
except ImportError:
module = None

# Use the module and item name to generate a log entry
generate_log(module, item.name)


def copy_local_files(local_resources, destination_dir):
"""
Expand Down Expand Up @@ -202,11 +176,12 @@ def collect(self):
optionflags=optionflags,
checker=DTChecker(config=self.config.dt_config)
)

try:
# We utilize scpdt's `find_doctests` function to discover doctests in public, non-deprecated objects in the module
# NB: additional postprocessing in pytest_collection_modifyitems
for test in find_doctests(module, strategy="api", name=module.__name__, config=dt_config):
if test.examples: # skip empty doctests
# if test.examples: # skip empty doctests # FIXME: put this back (simplifies comparing the logs)
yield doctest.DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
)
Expand Down
Loading