Skip to content

Commit

Permalink
Merge pull request #98 from roedoejet/dev.test-tours
Browse files Browse the repository at this point in the history
facility to test tours, plus give each step a default name
  • Loading branch information
joanise authored Sep 20, 2023
2 parents 5973114 + 5964fab commit f824660
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 101 deletions.
6 changes: 2 additions & 4 deletions everyvoice/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,13 @@ def test_read_filelist(self):
self.assertNotIn("speaker", self.filelist[0].keys())

def test_process_audio_for_alignment(self):
self.config = EveryVoiceConfig()
config = EveryVoiceConfig()
for entry in self.filelist[1:]:
# This just applies the SOX effects
audio, sr = self.preprocessor.process_audio(
self.wavs_dir / (entry["filename"] + ".wav"),
use_effects=True,
sox_effects=self.config.aligner.preprocessing.source_data[
0
].sox_effects,
sox_effects=config.aligner.preprocessing.source_data[0].sox_effects,
)
self.assertEqual(sr, 22050)
self.assertEqual(audio.dtype, float32)
Expand Down
135 changes: 89 additions & 46 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import string
import tempfile
from enum import Enum
from pathlib import Path
from types import MethodType
from typing import Sequence
Expand All @@ -22,7 +23,9 @@
patch_logger,
patch_menu_prompt,
)
from everyvoice.wizard import Step, StepNames, Tour, basic, dataset, prompts, validators
from everyvoice.wizard import Step
from everyvoice.wizard import StepNames as SN
from everyvoice.wizard import Tour, basic, dataset, prompts, validators


class WizardTest(TestCase):
Expand Down Expand Up @@ -51,18 +54,16 @@ def test_config_format_effect(self):
self.assertTrue(config_step.validate("json"))
with tempfile.TemporaryDirectory() as tmpdirname:
config_step.state = {}
config_step.state[StepNames.output_step.value] = tmpdirname
config_step.state[StepNames.name_step.value] = config_step.name
config_step.state[SN.output_step.value] = tmpdirname
config_step.state[SN.name_step.value] = config_step.name
config_step.state["dataset_test"] = {}
config_step.state["dataset_test"][
StepNames.symbol_set_step.value
] = Symbols(symbol_set=string.ascii_letters)
config_step.state["dataset_test"][StepNames.wavs_dir_step.value] = (
config_step.state["dataset_test"][SN.symbol_set_step.value] = Symbols(
symbol_set=string.ascii_letters
)
config_step.state["dataset_test"][SN.wavs_dir_step.value] = (
Path(tmpdirname) / "test"
)
config_step.state["dataset_test"][
StepNames.dataset_name_step.value
] = "test"
config_step.state["dataset_test"][SN.dataset_name_step.value] = "test"
config_step.state["dataset_test"]["filelist_data"] = [
{"basename": "0001", "text": "hello"},
{"basename": "0002", "text": "hello", None: "test"},
Expand Down Expand Up @@ -142,8 +143,8 @@ def test_output_path_step(self):
tour = Tour(
"testing",
[
basic.NameStep(StepNames.name_step.value),
basic.OutputPathStep(StepNames.output_step.value),
basic.NameStep(),
basic.OutputPathStep(),
],
)

Expand All @@ -157,15 +158,14 @@ def test_output_path_step(self):
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, "exits-as-file")
# Bad case 1: output dir exists and is a file
with open(file_path, "w") as f:
with open(file_path, "w", encoding="utf8") as f:
f.write("blah")
print("blah", file=f)
with patch_logger(basic) as logger, self.assertLogs(logger):
self.assertFalse(step.validate(file_path))

# Bad case 2: file called the same as the dataset exists in the output dir
dataset_file = os.path.join(tmpdirname, "myname")
with open(dataset_file, "w") as f:
with open(dataset_file, "w", encoding="utf8") as f:
f.write("blah")
with patch_logger(basic) as logger, self.assertLogs(logger):
self.assertFalse(step.validate(tmpdirname))
Expand All @@ -182,9 +182,7 @@ def test_output_path_step(self):

def test_more_data_step(self):
"""Exercise giving an invalid response and a yes response to more data."""
tour = Tour(
"testing", [basic.MoreDatasetsStep(name=StepNames.more_datasets_step.value)]
)
tour = Tour("testing", [basic.MoreDatasetsStep()])
step = tour.steps[0]
self.assertFalse(step.validate("foo"))
self.assertTrue(step.validate("yes"))
Expand All @@ -199,7 +197,7 @@ def test_more_data_step(self):
self.assertGreater(len(step.children), 5)

def test_dataset_name(self):
step = dataset.DatasetNameStep("")
step = dataset.DatasetNameStep()
with monkeypatch(builtins, "input", Say(("", "bad/name", "good-name"), True)):
with patch_logger(dataset) as logger, self.assertLogs(logger) as logs:
step.run()
Expand All @@ -215,8 +213,8 @@ def test_wavs_dir(self):

has_wavs_dir = os.path.join(tmpdirname, "there-are-wavs-here")
os.mkdir(has_wavs_dir)
with open(os.path.join(has_wavs_dir, "foo.wav"), "w") as f:
f.write("A fantastic sounding clip! (or not...)")
with open(os.path.join(has_wavs_dir, "foo.wav"), "wb") as f:
f.write(b"A fantastic sounding clip! (or not...)")

step = dataset.WavsDirStep("")
with monkeypatch(
Expand Down Expand Up @@ -253,27 +251,26 @@ def test_sample_rate_config(self):
self.assertEqual(step.response, 512)

def test_dataset_subtour(self):
def find_step(name: str, steps: Sequence[Step]):
def find_step(name: Enum, steps: Sequence[Step]):
for s in steps:
if s.name == name:
if s.name == name.value:
return s
raise IndexError(f"Step {name} not found.") # pragma: no cover

tour = Tour("unit testing", steps=dataset.return_dataset_steps())
filelist = str(self.data_dir / "unit-test-case1.psv")

filelist_step = find_step(StepNames.filelist_step.value, tour.steps)
with monkeypatch(filelist_step, "prompt", Say(filelist)):
filelist_step = find_step(SN.filelist_step, tour.steps)
monkey = monkeypatch(filelist_step, "prompt", Say(filelist))
with monkey:
filelist_step.run()
format_step = find_step(StepNames.filelist_format_step.value, tour.steps)
format_step = find_step(SN.filelist_format_step, tour.steps)
with patch_menu_prompt(0): # 0 is "psv"
format_step.run()
self.assertIsInstance(format_step.children[0], dataset.HeaderStep)
self.assertEqual(
format_step.children[0].name, StepNames.basename_header_step.value
)
self.assertEqual(format_step.children[0].name, SN.basename_header_step.value)
self.assertIsInstance(format_step.children[1], dataset.HeaderStep)
self.assertEqual(format_step.children[1].name, StepNames.text_header_step.value)
self.assertEqual(format_step.children[1].name, SN.text_header_step.value)

step = format_step.children[0]
with patch_menu_prompt(1): # 1 is second column
Expand All @@ -287,17 +284,13 @@ def find_step(name: str, steps: Sequence[Step]):
# print(step.state["filelist_headers"])
self.assertEqual(step.state["filelist_headers"][2], "text")

speaker_step = find_step(
StepNames.data_has_speaker_value_step.value, tour.steps
)
speaker_step = find_step(SN.data_has_speaker_value_step, tour.steps)
children_before = len(speaker_step.children)
with patch_menu_prompt(1): # 1 is "no"
speaker_step.run()
self.assertEqual(len(speaker_step.children), children_before)

language_step = find_step(
StepNames.data_has_language_value_step.value, tour.steps
)
language_step = find_step(SN.data_has_language_value_step, tour.steps)
children_before = len(language_step.children)
with patch_menu_prompt(1): # 1 is "no"
language_step.run()
Expand All @@ -314,9 +307,7 @@ def find_step(name: str, steps: Sequence[Step]):
["unknown_0", "basename", "text", "unknown_3"],
)

text_processing_step = find_step(
StepNames.text_processing_step.value, tour.steps
)
text_processing_step = find_step(SN.text_processing_step, tour.steps)
# 0 is lowercase, 1 is NFC Normalization, select both
with monkeypatch(dataset, "tqdm", lambda seq, desc: seq):
with patch_menu_prompt([0, 1]):
Expand All @@ -327,7 +318,7 @@ def find_step(name: str, steps: Sequence[Step]):
"cased \t nfd: éàê nfc: éàê", # the "nfd: éàê" bit here is now NFC
)

sox_effects_step = find_step(StepNames.sox_effects_step.value, tour.steps)
sox_effects_step = find_step(SN.sox_effects_step, tour.steps)
# 0 is resample to 22050 Hz, 2 is remove silence at start
with patch_menu_prompt([0, 2]):
sox_effects_step.run()
Expand All @@ -337,7 +328,7 @@ def find_step(name: str, steps: Sequence[Step]):
[["channel", "1"], ["rate", "22050"], ["silence", "1", "0.1", "1.0%"]],
)

symbol_set_step = find_step(StepNames.symbol_set_step.value, tour.steps)
symbol_set_step = find_step(SN.symbol_set_step, tour.steps)
self.assertEqual(len(symbol_set_step.state["filelist_data"]), 4)
with patch_menu_prompt([(0, 1, 2, 3), (11), ()], multi=True):
symbol_set_step.run()
Expand All @@ -353,8 +344,8 @@ def test_wrong_fileformat_psv(self):
tour = Tour(
name="mismatched fileformat",
steps=[
dataset.FilelistStep(StepNames.filelist_step.value),
dataset.FilelistFormatStep(StepNames.filelist_format_step.value),
dataset.FilelistStep(),
dataset.FilelistFormatStep(),
],
)
filelist = str(self.data_dir / "unit-test-case2.psv")
Expand All @@ -377,8 +368,8 @@ def test_wrong_fileformat_festival(self):
tour = Tour(
name="mismatched fileformat",
steps=[
dataset.FilelistStep(StepNames.filelist_step.value),
dataset.FilelistFormatStep(StepNames.filelist_format_step.value),
dataset.FilelistStep(),
dataset.FilelistFormatStep(),
],
)
filelist = str(self.data_dir / "unit-test-case3.festival")
Expand Down Expand Up @@ -420,7 +411,7 @@ def test_validate_path(self):
validate_path(tmpdirname, is_dir=True, is_file=False, exists=False)
)
file_name = os.path.join(tmpdirname, "some-file-name")
with open(file_name, "w") as f:
with open(file_name, "w", encoding="utf8") as f:
f.write("foo")
self.assertTrue(
validate_path(file_name, is_dir=False, is_file=True, exists=True)
Expand Down Expand Up @@ -460,6 +451,58 @@ def test_prompt(self):
)
self.assertEqual(answer, 1)

def monkey_run_tour(self, name, steps):
    """Build a Tour from scripted steps and run each step under its patch.

    ``steps`` is a sequence of tuples whose first item is a Step and whose
    second item is either a ``Say`` instance, which is installed as the
    step's ``prompt`` attribute while the step runs, or a ready-made
    context manager that is entered as-is. Extra tuple items are ignored.

    Returns the Tour after every step has been run.
    """
    tour = Tour(name, steps=[first for (first, *_) in steps])
    # A freshly built Tour must start empty; a non-empty state here means
    # the initializer's default state was accidentally shared.
    self.assertEqual(tour.state, {})
    for entry in steps:
        step, answer, *_ = entry
        patcher = (
            monkeypatch(step, "prompt", answer)
            if isinstance(answer, Say)
            else answer
        )
        with patcher:
            step.run()
    return tour

def test_monkey_tour_1(self):
    """Run the name and output-path steps with scripted answers and check the state."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        scripted_steps = [
            (basic.NameStep(), Say("my-dataset-name")),
            (basic.OutputPathStep(), Say(tmpdirname)),
        ]
        tour = self.monkey_run_tour("monkey tour 1", scripted_steps)
    final_state = tour.state
    self.assertEqual(final_state[SN.name_step.value], "my-dataset-name")
    self.assertEqual(final_state[SN.output_step.value], tmpdirname)

def test_monkey_tour_2(self):
    """Run the dataset steps with scripted answers and check the filelist size."""
    data_dir = Path(__file__).parent / "data"
    metadata_path = str(data_dir / "metadata.csv")
    # Each entry pairs a step with the answer (or context manager) used
    # while that step runs; the order of entries is the order of the tour.
    scripted_steps = [
        (dataset.WavsDirStep(), Say(data_dir)),
        (dataset.FilelistStep(), Say(metadata_path)),
        (dataset.FilelistFormatStep(), Say("psv")),
        (dataset.HasSpeakerStep(), Say("yes")),
        (dataset.HasLanguageStep(), Say("yes")),
        (dataset.SelectLanguageStep(), Say("eng")),
        (dataset.TextProcessingStep(), Say([0, 1])),
        (
            dataset.SymbolSetStep(),
            patch_menu_prompt([(0, 1, 2, 3, 4), (), ()], multi=True),
        ),
        (dataset.SoxEffectsStep(), Say([0])),
        (dataset.DatasetNameStep(), Say("my-monkey-dataset")),
    ]
    tour = self.monkey_run_tour("monkey tour 2", scripted_steps)

    loaded_rows = tour.state["filelist_data"]
    self.assertEqual(len(loaded_rows), 6)


if __name__ == "__main__":
main()
26 changes: 14 additions & 12 deletions everyvoice/wizard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List
from typing import List, Optional, Union

from anytree import NodeMixin, RenderTree
from questionary import Style
Expand Down Expand Up @@ -64,7 +64,7 @@ class Step(_Step, NodeMixin):

def __init__(
self,
name,
name: Union[None, Enum, str] = None,
default=None,
prompt_method=None,
validate_method=None,
Expand All @@ -74,18 +74,20 @@ def __init__(
state_subset=None,
):
super(Step, self).__init__()
self.name = name
if name is None:
name = getattr(self, "DEFAULT_NAME", "default step name missing")
self.name: str = name.value if isinstance(name, Enum) else name
self.default = default
self.parent = parent
self.state_subset = state_subset
self.state = None # should be added when the Step is added to a Tour
self.tour = None # should be added when the Step is added to a Tour
if effect_method is not None:
self.effect = effect_method
self.effect = effect_method # type: ignore[method-assign]
if prompt_method is not None:
self.prompt = prompt_method
self.prompt = prompt_method # type: ignore[method-assign]
if validate_method is not None:
self.validate = validate_method
self.validate = validate_method # type: ignore[method-assign]
if children:
self.children = children

Expand All @@ -110,17 +112,17 @@ def run(self):


class Tour:
def __init__(self, name: str, steps: List[Step], state: dict = {}):
def __init__(self, name: str, steps: List[Step], state: Optional[dict] = None):
"""Create the tour by setting each Step as the child of the previous Step."""
self.name = name
self.state = state
self.state: dict = state if state is not None else {}
for parent, child in zip(steps, steps[1:]):
child.parent = parent
self.determine_state(child, self.state)
child.tour = self
child.tour = self # type: ignore
self.steps = steps
self.root = steps[0]
self.root.tour = self
self.root.tour = self # type: ignore
self.determine_state(self.root, self.state)

def determine_state(self, step: Step, state: dict):
Expand All @@ -129,11 +131,11 @@ def determine_state(self, step: Step, state: dict):
state[step.state_subset] = {}
step.state = state[step.state_subset]
else:
step.state = state
step.state = state # type: ignore

def add_step(self, step: Step, parent: Step, child_index=0):
self.determine_state(step, self.state)
step.tour = self
step.tour = self # type: ignore
children = list(parent.children)
children.insert(child_index, step)
parent.children = children
Expand Down
Loading

0 comments on commit f824660

Please sign in to comment.