Skip to content

Commit

Permalink
refactor: move test harness code out of Step into TestStep
Browse files Browse the repository at this point in the history
also:
 - document the methods in the wizard's Tour class
 - remove the unused child_index arg to add_step
  • Loading branch information
joanise committed Sep 20, 2024
1 parent 696a85d commit 832a2ba
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 26 deletions.
42 changes: 31 additions & 11 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,36 @@ def find_step(name: Enum, steps: Sequence[Step | list[Step]]):
raise IndexError(f"Step {name} not found.") # pragma: no cover


class TestStep(Step):
"""A Step subclass that allows specified the effect, prompt and validate
methods in the constructor."""

def __init__(
self,
name: str,
parent=None,
prompt_method: Optional[Callable] = None,
validate_method: Optional[Callable] = None,
effect_method: Optional[Callable] = None,
):
super().__init__(name, parent=parent)
if prompt_method:
self.prompt = prompt_method # type: ignore[method-assign]
if validate_method:
self.validate = validate_method # type: ignore[method-assign]
if effect_method:
self.effect = effect_method # type: ignore[method-assign]


class WizardTest(TestCase):
"""Basic test for the configuration wizard"""

data_dir = Path(__file__).parent / "data"

def test_implementation_missing(self):
nothing_step = Step(name="Dummy Step")
no_validate_step = Step(name="Dummy Step", prompt_method=lambda: "test")
no_prompt_step = Step(name="Dummy Step", validate_method=lambda: True)
no_validate_step = TestStep(name="Dummy Step", prompt_method=lambda: "test")
no_prompt_step = TestStep(name="Dummy Step", validate_method=lambda: True)
for step in [nothing_step, no_validate_step, no_prompt_step]:
with self.assertRaises(NotImplementedError):
step.run()
Expand Down Expand Up @@ -159,7 +180,7 @@ def test_config_format_effect(self):
)

def test_access_response(self):
root_step = Step(
root_step = TestStep(
name="Dummy Step",
prompt_method=lambda: "foo",
validate_method=lambda x: True,
Expand All @@ -170,16 +191,16 @@ def validate(self, x):
if self.parent.response + x == "foobar":
return True

second_step = Step(
second_step = TestStep(
name="Dummy Step 2", prompt_method=lambda: "bar", parent=root_step
)
second_step.validate = MethodType(validate, second_step)
for i, leaf in enumerate(RenderTree(root_step)):
for i, node in enumerate(PreOrderIter(root_step)):
if i != 0:
self.assertEqual(second_step.parent.response, "foo")
self.assertTrue(leaf[2].validate("bar"))
self.assertFalse(leaf[2].validate("foo"))
leaf[2].run()
self.assertTrue(node.validate("bar"))
self.assertFalse(node.validate("foo"))
node.run()

def test_main_tour(self):
tour = get_main_wizard_tour()
Expand Down Expand Up @@ -463,7 +484,7 @@ def test_whitespace_always_collapsed(self):
know_speaker_step.run()

add_speaker_step = know_speaker_step.children[0]
with patch_input("default"):
with patch_input("default"), capture_stdout():
add_speaker_step.run()

language_step = find_step(SN.data_has_language_value_step, tour.steps)
Expand Down Expand Up @@ -559,7 +580,7 @@ def test_dataset_subtour(self):

add_speaker_step = know_speaker_step.children[0]
children_before = len(add_speaker_step.children)
with patch_input("default"):
with patch_input("default"), capture_stdout():
add_speaker_step.run()
self.assertEqual(len(add_speaker_step.children), children_before)

Expand Down Expand Up @@ -1724,7 +1745,6 @@ def setUp(self):
SN.wavs_dir_step.value: "Common-Voice",
SN.symbol_set_step.value: {
"characters": [
" ",
"A",
"D",
"E",
Expand Down
29 changes: 15 additions & 14 deletions everyvoice/wizard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""The main module for the wizard package."""

import sys
from enum import Enum
from typing import Optional, Sequence
Expand Down Expand Up @@ -64,11 +66,7 @@ def __init__(
self,
name: None | Enum | str = None,
default=None,
prompt_method=None,
validate_method=None,
effect_method=None,
parent=None,
children=None,
state_subset=None,
):
if name is None:
Expand All @@ -82,14 +80,6 @@ def __init__(
self.state: Optional[State] = None
# tour will be added when the Step is added to a Tour
self.tour: Optional[Tour] = None
if effect_method is not None:
self.effect = effect_method # type: ignore[method-assign]
if prompt_method is not None:
self.prompt = prompt_method # type: ignore[method-assign]
if validate_method is not None:
self.validate = validate_method # type: ignore[method-assign]
if children:
self.children = children
self._validation_failures = 0

def __repr__(self) -> str:
Expand Down Expand Up @@ -138,6 +128,7 @@ def __init__(self, name: str, steps: list[Step], state: Optional[State] = None):
self.add_steps(steps, self.root)

def determine_state(self, step: Step, state: State):
"""Determines the state to use for the step based on the state subset"""
if step.state_subset is not None:
if step.state_subset not in state:
state[step.state_subset] = State()
Expand All @@ -159,18 +150,28 @@ def add_steps(self, steps: Sequence[Step | list[Step]], parent: Step):
else:
self.add_step(item, parent)

def add_step(self, step: Step, parent: Step, child_index=0):
def add_step(self, step: Step, parent: Step):
"""Insert a step in the specified position in the tour.
Args:
step: The step to add
parent: The parent to add the step to
"""
self.determine_state(step, self.state)
step.tour = self
children = list(parent.children)
children.insert(child_index, step)
children.insert(0, step)
parent.children = children

def run(self):
"""Run the tour by traversing the tree depth-first"""
for _, _, node in RenderTree(self.root):
self.visualize()
print(f"Running {node.name}")
node.run()

def visualize(self):
"""Display the tree structure of the tour on stdout"""
for pre, _, node in RenderTree(self.root):
treestr = f"{pre}{node.name}"
print(treestr.ljust(8))
Expand Down
2 changes: 1 addition & 1 deletion everyvoice/wizard/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def effect(self):
# so this should be fixed by https://github.com/EveryVoiceTTS/EveryVoice/issues/359
if dataset_state.get(StepNames.text_processing_step):
global_cleaners += [
TextProcessingStep().process_lookup[x]["fn"]
TextProcessingStep.process_lookup[x]["fn"]
for x in dataset_state[StepNames.text_processing_step]
]
# Gather Symbols for Text Configuration
Expand Down

0 comments on commit 832a2ba

Please sign in to comment.