diff --git a/everyvoice/tests/test_preprocessing.py b/everyvoice/tests/test_preprocessing.py index d834819f..8621399e 100755 --- a/everyvoice/tests/test_preprocessing.py +++ b/everyvoice/tests/test_preprocessing.py @@ -85,15 +85,13 @@ def test_read_filelist(self): self.assertNotIn("speaker", self.filelist[0].keys()) def test_process_audio_for_alignment(self): - self.config = EveryVoiceConfig() + config = EveryVoiceConfig() for entry in self.filelist[1:]: # This just applies the SOX effects audio, sr = self.preprocessor.process_audio( self.wavs_dir / (entry["filename"] + ".wav"), use_effects=True, - sox_effects=self.config.aligner.preprocessing.source_data[ - 0 - ].sox_effects, + sox_effects=config.aligner.preprocessing.source_data[0].sox_effects, ) self.assertEqual(sr, 22050) self.assertEqual(audio.dtype, float32) diff --git a/everyvoice/tests/test_wizard.py b/everyvoice/tests/test_wizard.py index e5504de2..01c8d419 100755 --- a/everyvoice/tests/test_wizard.py +++ b/everyvoice/tests/test_wizard.py @@ -4,6 +4,7 @@ import os import string import tempfile +from enum import Enum from pathlib import Path from types import MethodType from typing import Sequence @@ -22,7 +23,9 @@ patch_logger, patch_menu_prompt, ) -from everyvoice.wizard import Step, StepNames, Tour, basic, dataset, prompts, validators +from everyvoice.wizard import Step +from everyvoice.wizard import StepNames as SN +from everyvoice.wizard import Tour, basic, dataset, prompts, validators class WizardTest(TestCase): @@ -51,18 +54,16 @@ def test_config_format_effect(self): self.assertTrue(config_step.validate("json")) with tempfile.TemporaryDirectory() as tmpdirname: config_step.state = {} - config_step.state[StepNames.output_step.value] = tmpdirname - config_step.state[StepNames.name_step.value] = config_step.name + config_step.state[SN.output_step.value] = tmpdirname + config_step.state[SN.name_step.value] = config_step.name config_step.state["dataset_test"] = {} - config_step.state["dataset_test"][ - StepNames.symbol_set_step.value - ] = Symbols(symbol_set=string.ascii_letters) - config_step.state["dataset_test"][StepNames.wavs_dir_step.value] = ( + config_step.state["dataset_test"][SN.symbol_set_step.value] = Symbols( + symbol_set=string.ascii_letters + ) + config_step.state["dataset_test"][SN.wavs_dir_step.value] = ( Path(tmpdirname) / "test" ) - config_step.state["dataset_test"][ - StepNames.dataset_name_step.value - ] = "test" + config_step.state["dataset_test"][SN.dataset_name_step.value] = "test" config_step.state["dataset_test"]["filelist_data"] = [ {"basename": "0001", "text": "hello"}, {"basename": "0002", "text": "hello", None: "test"}, @@ -142,8 +143,8 @@ def test_output_path_step(self): tour = Tour( "testing", [ - basic.NameStep(StepNames.name_step.value), - basic.OutputPathStep(StepNames.output_step.value), + basic.NameStep(), + basic.OutputPathStep(), ], ) @@ -157,15 +158,14 @@ def test_output_path_step(self): with tempfile.TemporaryDirectory() as tmpdirname: file_path = os.path.join(tmpdirname, "exits-as-file") # Bad case 1: output dir exists and is a file - with open(file_path, "w") as f: + with open(file_path, "w", encoding="utf8") as f: f.write("blah") - print("blah", file=f) with patch_logger(basic) as logger, self.assertLogs(logger): self.assertFalse(step.validate(file_path)) # Bad case 2: file called the same as the dataset exists in the output dir dataset_file = os.path.join(tmpdirname, "myname") - with open(dataset_file, "w") as f: + with open(dataset_file, "w", encoding="utf8") as f: f.write("blah") with patch_logger(basic) as logger, self.assertLogs(logger): self.assertFalse(step.validate(tmpdirname)) @@ -182,9 +182,7 @@ def test_output_path_step(self): def test_more_data_step(self): """Exercise giving an invalid response and a yes response to more data.""" - tour = Tour( - "testing", [basic.MoreDatasetsStep(name=StepNames.more_datasets_step.value)] - ) + tour = Tour("testing", [basic.MoreDatasetsStep()]) step = tour.steps[0] self.assertFalse(step.validate("foo")) self.assertTrue(step.validate("yes")) @@ -199,7 +197,7 @@ def test_more_data_step(self): self.assertGreater(len(step.children), 5) def test_dataset_name(self): - step = dataset.DatasetNameStep("") + step = dataset.DatasetNameStep() with monkeypatch(builtins, "input", Say(("", "bad/name", "good-name"), True)): with patch_logger(dataset) as logger, self.assertLogs(logger) as logs: step.run() @@ -215,8 +213,8 @@ def test_wavs_dir(self): has_wavs_dir = os.path.join(tmpdirname, "there-are-wavs-here") os.mkdir(has_wavs_dir) - with open(os.path.join(has_wavs_dir, "foo.wav"), "w") as f: - f.write("A fantastic sounding clip! (or not...)") + with open(os.path.join(has_wavs_dir, "foo.wav"), "wb") as f: + f.write(b"A fantastic sounding clip! (or not...)") step = dataset.WavsDirStep("") with monkeypatch( @@ -253,27 +251,26 @@ def test_sample_rate_config(self): self.assertEqual(step.response, 512) def test_dataset_subtour(self): - def find_step(name: str, steps: Sequence[Step]): + def find_step(name: Enum, steps: Sequence[Step]): for s in steps: - if s.name == name: + if s.name == name.value: return s raise IndexError(f"Step {name} not found.") # pragma: no cover tour = Tour("unit testing", steps=dataset.return_dataset_steps()) filelist = str(self.data_dir / "unit-test-case1.psv") - filelist_step = find_step(StepNames.filelist_step.value, tour.steps) - with monkeypatch(filelist_step, "prompt", Say(filelist)): + filelist_step = find_step(SN.filelist_step, tour.steps) + monkey = monkeypatch(filelist_step, "prompt", Say(filelist)) + with monkey: filelist_step.run() - format_step = find_step(StepNames.filelist_format_step.value, tour.steps) + format_step = find_step(SN.filelist_format_step, tour.steps) with patch_menu_prompt(0): # 0 is "psv" format_step.run() self.assertIsInstance(format_step.children[0], dataset.HeaderStep) - self.assertEqual( - format_step.children[0].name, StepNames.basename_header_step.value - ) + self.assertEqual(format_step.children[0].name, SN.basename_header_step.value) self.assertIsInstance(format_step.children[1], dataset.HeaderStep) - self.assertEqual(format_step.children[1].name, StepNames.text_header_step.value) + self.assertEqual(format_step.children[1].name, SN.text_header_step.value) step = format_step.children[0] with patch_menu_prompt(1): # 1 is second column @@ -287,17 +284,13 @@ def find_step(name: str, steps: Sequence[Step]): # print(step.state["filelist_headers"]) self.assertEqual(step.state["filelist_headers"][2], "text") - speaker_step = find_step( - StepNames.data_has_speaker_value_step.value, tour.steps - ) + speaker_step = find_step(SN.data_has_speaker_value_step, tour.steps) children_before = len(speaker_step.children) with patch_menu_prompt(1): # 1 is "no" speaker_step.run() self.assertEqual(len(speaker_step.children), children_before) - language_step = find_step( - StepNames.data_has_language_value_step.value, tour.steps - ) + language_step = find_step(SN.data_has_language_value_step, tour.steps) children_before = len(language_step.children) with patch_menu_prompt(1): # 1 is "no" language_step.run() @@ -314,9 +307,7 @@ def find_step(name: str, steps: Sequence[Step]): ["unknown_0", "basename", "text", "unknown_3"], ) - text_processing_step = find_step( - StepNames.text_processing_step.value, tour.steps - ) + text_processing_step = find_step(SN.text_processing_step, tour.steps) # 0 is lowercase, 1 is NFC Normalization, select both with monkeypatch(dataset, "tqdm", lambda seq, desc: seq): with patch_menu_prompt([0, 1]): @@ -327,7 +318,7 @@ def find_step(name: str, steps: Sequence[Step]): "cased \t nfd: éàê nfc: éàê", # the "nfd: éàê" bit here is now NFC ) - sox_effects_step = find_step(StepNames.sox_effects_step.value, tour.steps) + sox_effects_step = find_step(SN.sox_effects_step, tour.steps) # 0 is resample to 22050 kHz, 2 is remove silence at start with patch_menu_prompt([0, 2]): sox_effects_step.run() @@ -337,7 +328,7 @@ def find_step(name: str, steps: Sequence[Step]): [["channel", "1"], ["rate", "22050"], ["silence", "1", "0.1", "1.0%"]], ) - symbol_set_step = find_step(StepNames.symbol_set_step.value, tour.steps) + symbol_set_step = find_step(SN.symbol_set_step, tour.steps) self.assertEqual(len(symbol_set_step.state["filelist_data"]), 4) with patch_menu_prompt([(0, 1, 2, 3), (11), ()], multi=True): symbol_set_step.run() @@ -353,8 +344,8 @@ def test_wrong_fileformat_psv(self): tour = Tour( name="mismatched fileformat", steps=[ - dataset.FilelistStep(StepNames.filelist_step.value), - dataset.FilelistFormatStep(StepNames.filelist_format_step.value), + dataset.FilelistStep(), + dataset.FilelistFormatStep(), ], ) filelist = str(self.data_dir / "unit-test-case2.psv") @@ -377,8 +368,8 @@ def test_wrong_fileformat_festival(self): tour = Tour( name="mismatched fileformat", steps=[ - dataset.FilelistStep(StepNames.filelist_step.value), - dataset.FilelistFormatStep(StepNames.filelist_format_step.value), + dataset.FilelistStep(), + dataset.FilelistFormatStep(), ], ) filelist = str(self.data_dir / "unit-test-case3.festival") @@ -420,7 +411,7 @@ def test_validate_path(self): validate_path(tmpdirname, is_dir=True, is_file=False, exists=False) ) file_name = os.path.join(tmpdirname, "some-file-name") - with open(file_name, "w") as f: + with open(file_name, "w", encoding="utf8") as f: f.write("foo") self.assertTrue( validate_path(file_name, is_dir=False, is_file=True, exists=True) @@ -460,6 +451,58 @@ def test_prompt(self): ) self.assertEqual(answer, 1) + def monkey_run_tour(self, name, steps): + tour = Tour(name, steps=[step for (step, *_) in steps]) + self.assertEqual(tour.state, {}) # fail on accidentally shared initializer + for (step, answer, *_) in steps: + if isinstance(answer, Say): + monkey = monkeypatch(step, "prompt", answer) + else: + monkey = answer + # print(step.name) + with monkey: + step.run() + return tour + + def test_monkey_tour_1(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tour = self.monkey_run_tour( + "monkey tour 1", + [ + (basic.NameStep(), Say("my-dataset-name")), + (basic.OutputPathStep(), Say(tmpdirname)), + ], + ) + self.assertEqual(tour.state[SN.name_step.value], "my-dataset-name") + self.assertEqual(tour.state[SN.output_step.value], tmpdirname) + + def test_monkey_tour_2(self): + data_dir = Path(__file__).parent / "data" + tour = self.monkey_run_tour( + "monkey tour 2", + [ + (dataset.WavsDirStep(), Say(data_dir)), + ( + dataset.FilelistStep(), + Say(str(data_dir / "metadata.csv")), + ), + (dataset.FilelistFormatStep(), Say("psv")), + (dataset.HasSpeakerStep(), Say("yes")), + (dataset.HasLanguageStep(), Say("yes")), + (dataset.SelectLanguageStep(), Say("eng")), + (dataset.TextProcessingStep(), Say([0, 1])), + ( + dataset.SymbolSetStep(), + patch_menu_prompt([(0, 1, 2, 3, 4), (), ()], multi=True), + ), + (dataset.SoxEffectsStep(), Say([0])), + (dataset.DatasetNameStep(), Say("my-monkey-dataset")), + ], + ) + + # print(tour.state) + self.assertEqual(len(tour.state["filelist_data"]), 6) + if __name__ == "__main__": main() diff --git a/everyvoice/wizard/__init__.py b/everyvoice/wizard/__init__.py index a1805b6f..2db68d66 100644 --- a/everyvoice/wizard/__init__.py +++ b/everyvoice/wizard/__init__.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List +from typing import List, Optional, Union from anytree import NodeMixin, RenderTree from questionary import Style @@ -64,7 +64,7 @@ class Step(_Step, NodeMixin): def __init__( self, - name, + name: Union[None, Enum, str] = None, default=None, prompt_method=None, validate_method=None, @@ -74,18 +74,20 @@ def __init__( state_subset=None, ): super(Step, self).__init__() - self.name = name + if name is None: + name = getattr(self, "DEFAULT_NAME", "default step name missing") + self.name: str = name.value if isinstance(name, Enum) else name self.default = default self.parent = parent self.state_subset = state_subset self.state = None # should be added when the Step is added to a Tour self.tour = None # should be added when the Step is added to a Tour if effect_method is not None: - self.effect = effect_method + self.effect = effect_method # type: ignore[method-assign] if prompt_method is not None: - self.prompt = prompt_method + self.prompt = prompt_method # type: ignore[method-assign] if validate_method is not None: - self.validate = validate_method + self.validate = validate_method # type: ignore[method-assign] if children: self.children = children @@ -110,17 +112,17 @@ def run(self): class Tour: - def __init__(self, name: str, steps: List[Step], state: dict = {}): + def __init__(self, name: str, steps: List[Step], state: Optional[dict] = None): """Create the tour by setting each Step as the child of the previous Step.""" self.name = name - self.state = state + self.state: dict = state if state is not None else {} for parent, child in zip(steps, steps[1:]): child.parent = parent self.determine_state(child, self.state) - child.tour = self + child.tour = self # type: ignore self.steps = steps self.root = steps[0] - self.root.tour = self + self.root.tour = self # type: ignore self.determine_state(self.root, self.state) def determine_state(self, step: Step, state: dict): @@ -129,11 +131,11 @@ def determine_state(self, step: Step, state: dict): state[step.state_subset] = {} step.state = state[step.state_subset] else: - step.state = state + step.state = state # type: ignore def add_step(self, step: Step, parent: Step, child_index=0): self.determine_state(step, self.state) - step.tour = self + step.tour = self # type: ignore children = list(parent.children) children.insert(child_index, step) parent.children = children diff --git a/everyvoice/wizard/basic.py b/everyvoice/wizard/basic.py index 44c4a55b..7fe78caf 100644 --- a/everyvoice/wizard/basic.py +++ b/everyvoice/wizard/basic.py @@ -22,6 +22,8 @@ class NameStep(Step): + DEFAULT_NAME = StepNames.name_step + def prompt(self): return input("What would you like to call this project? ") @@ -44,6 +46,8 @@ def effect(self): class OutputPathStep(Step): + DEFAULT_NAME = StepNames.output_step + def prompt(self): return questionary.path( "Where should the New Dataset Wizard save your files?", @@ -75,6 +79,8 @@ def effect(self): class ConfigFormatStep(Step): + DEFAULT_NAME = StepNames.config_format_step + def prompt(self): return get_response_from_menu_prompt( "Which format would you like to output the configuration to?", @@ -242,6 +248,8 @@ def effect(self): class MoreDatasetsStep(Step): + DEFAULT_NAME = StepNames.more_datasets_step + def prompt(self): return get_response_from_menu_prompt( "Do you have more datasets to process?", diff --git a/everyvoice/wizard/dataset.py b/everyvoice/wizard/dataset.py index ffb257c3..16acbb55 100644 --- a/everyvoice/wizard/dataset.py +++ b/everyvoice/wizard/dataset.py @@ -18,6 +18,8 @@ class DatasetNameStep(Step): + DEFAULT_NAME = StepNames.dataset_name_step + def prompt(self): return input("What would you like to call this dataset? ") @@ -41,6 +43,8 @@ def effect(self): class WavsDirStep(Step): + DEFAULT_NAME = StepNames.wavs_dir_step + def prompt(self): return questionary.path( "Where are your audio files?", style=CUSTOM_QUESTIONARY_STYLE @@ -56,6 +60,8 @@ def validate(self, response): class SampleRateConfigStep(Step): + DEFAULT_NAME = StepNames.sample_rate_config_step + def prompt(self): return questionary.text( "What is the sample rate (in Hertz) of your data?", @@ -79,6 +85,8 @@ def validate(self, response): class FilelistStep(Step): + DEFAULT_NAME = StepNames.filelist_step + def prompt(self): return questionary.path( "Where is your data filelist?", style=CUSTOM_QUESTIONARY_STYLE @@ -89,6 +97,7 @@ def validate(self, response): class FilelistFormatStep(Step): + DEFAULT_NAME = StepNames.filelist_format_step separators = {"psv": "|", "tsv": "\t", "csv": ","} def prompt(self): @@ -192,6 +201,8 @@ def effect(self): class HeaderStep(Step): + DEFAULT_NAME = StepNames.text_header_step + def __init__(self, name: str, prompt_text: str, header_name: str, **kwargs): super(HeaderStep, self).__init__(name=name, **kwargs) self.prompt_text = prompt_text @@ -234,6 +245,7 @@ def effect(self): class HasSpeakerStep(Step): + DEFAULT_NAME = StepNames.data_has_speaker_value_step choices = ("yes", "no") def prompt(self): @@ -262,6 +274,7 @@ def effect(self): class HasLanguageStep(Step): + DEFAULT_NAME = StepNames.data_has_language_value_step choices = ("yes", "no") def prompt(self): @@ -298,6 +311,8 @@ def effect(self): class SelectLanguageStep(Step): + DEFAULT_NAME = StepNames.select_language_step + def prompt(self): from g2p import get_arpabet_langs @@ -364,6 +379,8 @@ def return_symbols(language): class TextProcessingStep(Step): + DEFAULT_NAME = StepNames.text_processing_step + def prompt(self): return get_response_from_menu_prompt( prompt_text="Which of the following text transformations would like to apply before determining the symbol set?", @@ -386,16 +403,19 @@ def effect(self): } if self.response is not None and len(self.response): for process in self.response: + process_fn = process_lookup[process]["fn"] for i in tqdm( range(len(self.state["filelist_data"])), desc=f"Applying {process_lookup[process]['desc']} to data", ): - self.state["filelist_data"][i]["text"] = process_lookup[process][ - "fn" - ](self.state["filelist_data"][i]["text"]) + self.state["filelist_data"][i]["text"] = process_fn( + self.state["filelist_data"][i]["text"] + ) class SoxEffectsStep(Step): + DEFAULT_NAME = StepNames.sox_effects_step + def prompt(self): return get_response_from_menu_prompt( prompt_text="Which of the following audio preprocessing options would you like to apply?", @@ -427,6 +447,8 @@ def effect(self): class SymbolSetStep(Step): + DEFAULT_NAME = StepNames.symbol_set_step + def prompt(self): selected_language = get_iso_code( self.state.get(StepNames.select_language_step.value, None) @@ -501,42 +523,15 @@ def effect(self): def return_dataset_steps(dataset_index=0): return [ - WavsDirStep( - name=StepNames.wavs_dir_step.value, - state_subset=f"dataset_{dataset_index}", - ), - FilelistStep( - name=StepNames.filelist_step.value, - state_subset=f"dataset_{dataset_index}", - ), - FilelistFormatStep( - name=StepNames.filelist_format_step.value, - state_subset=f"dataset_{dataset_index}", - ), - HasSpeakerStep( - name=StepNames.data_has_speaker_value_step.value, - state_subset=f"dataset_{dataset_index}", - ), - HasLanguageStep( - name=StepNames.data_has_language_value_step.value, - state_subset=f"dataset_{dataset_index}", - ), - TextProcessingStep( - name=StepNames.text_processing_step.value, - state_subset=f"dataset_{dataset_index}", - ), - SymbolSetStep( - name=StepNames.symbol_set_step.value, - state_subset=f"dataset_{dataset_index}", - ), - SoxEffectsStep( - name=StepNames.sox_effects_step.value, - state_subset=f"dataset_{dataset_index}", - ), - DatasetNameStep( - name=StepNames.dataset_name_step.value, - state_subset=f"dataset_{dataset_index}", - ), + WavsDirStep(state_subset=f"dataset_{dataset_index}"), + FilelistStep(state_subset=f"dataset_{dataset_index}"), + FilelistFormatStep(state_subset=f"dataset_{dataset_index}"), + HasSpeakerStep(state_subset=f"dataset_{dataset_index}"), + HasLanguageStep(state_subset=f"dataset_{dataset_index}"), + TextProcessingStep(state_subset=f"dataset_{dataset_index}"), + SymbolSetStep(state_subset=f"dataset_{dataset_index}"), + SoxEffectsStep(state_subset=f"dataset_{dataset_index}"), + DatasetNameStep(state_subset=f"dataset_{dataset_index}"), ]