Skip to content

Commit

Permalink
Merge pull request #8 from Kor-SVS/dev
Browse files Browse the repository at this point in the history
통계 기능 추가 및 버그 수정 v0.1.5
  • Loading branch information
Cardroid committed Sep 8, 2022
2 parents 3ce6044 + 503c3bd commit 57e4594
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 37 deletions.
4 changes: 3 additions & 1 deletion src/enunu_kor_tool/analysis4vb/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_root_module_logger():
logger.info(L("ustx -> ust 변환 중..."))
os.makedirs(db_config.output.temp, exist_ok=True)
for ustx_path in (db_raw_ustx_files_tqdm := tqdm(db_raw_ustx_files)):
db_raw_ustx_files_tqdm.set_description(f"ustx -> ust Converting... [{os.path.relpath(ustx_path)}]")
db_raw_ustx_files_tqdm.set_description(f"ustx -> ust Converting... [{ustx_path}]")
if (ustx_path_split := os.path.splitext(ustx_path))[1] == ".ustx":
converter = Ustx2Ust_Converter(ustx_path, encoding="utf-8")
converter.save_ust(os.path.join(db_config.output.temp, os.path.basename(ustx_path_split[0]) + ".ust"))
Expand Down Expand Up @@ -119,6 +119,8 @@ def get_root_module_logger():
if os.path.exists(db_info.config.output.temp):
shutil.rmtree(db_info.config.output.temp)

del config

logger.info("Cleaned up.")


Expand Down
4 changes: 3 additions & 1 deletion src/enunu_kor_tool/analysis4vb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
},
"options": {
"log_level": "info",
"encoding": "utf-8",
"lab_encoding": "utf-8",
"ust_encoding": "cp932",
"use_100ns": False,
"graph_save": True,
"graph_show": False,
"graph_darkmode": True,
Expand Down
1 change: 1 addition & 0 deletions src/enunu_kor_tool/analysis4vb/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def join_module_name(func_name: str):
"lab_error_check": {"module": join_module_name("lab"), "func": "lab_error_check"},
"phoneme_count": {"module": join_module_name("lab"), "func": "phoneme_count"},
"phoneme_length": {"module": join_module_name("lab"), "func": "phoneme_length"},
"phoneme_average_length": {"module": join_module_name("lab"), "func": "phoneme_average_length"},
"ust_error_check": {"module": join_module_name("ust"), "func": "ust_error_check"},
"pitch_note_count": {"module": join_module_name("ust"), "func": "pitch_note_count"},
"pitch_note_length": {"module": join_module_name("ust"), "func": "pitch_note_length"},
Expand Down
189 changes: 170 additions & 19 deletions src/enunu_kor_tool/analysis4vb/functions/lab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import Dict
from typing import Dict, List

from tqdm import tqdm

Expand Down Expand Up @@ -38,7 +38,8 @@ def __lab_loader(db_info: DB_Info, logger: logging.Logger) -> bool:
"""

phonemes_files = db_info.files.lab
encoding = db_info.config.options.get("encoding", "utf-8")
use_100ns = db_info.config.options.get("use_100ns", False)
encoding = db_info.config.options.get("lab_encoding", "utf-8")
line_num_formatter = lambda ln: str(ln).rjust(4)

error_flag = False
Expand All @@ -47,7 +48,7 @@ def __lab_loader(db_info: DB_Info, logger: logging.Logger) -> bool:
lab_global_error_line_count = 0
for file in (file_tqdm := tqdm(phonemes_files, leave=False)):
file_tqdm.set_description(f"Processing... [{file}]")
logger.info(L("[{filepath}] 파일 로드 중...", filepath=os.path.relpath(file)))
logger.info(L("[{filepath}] 파일 로드 중...", filepath=file))

lab = []

Expand Down Expand Up @@ -96,7 +97,7 @@ def __lab_loader(db_info: DB_Info, logger: logging.Logger) -> bool:
L(
"lab 파일을 로드했습니다. [총 라인 수: {line_num}] [길이: {round_length}s ({length} 100ns)] [오류 라인 수: {error_line_count}]",
line_num=line_num_formatter(lab_len),
round_length=round(__100ns2s(length), 1),
round_length=length if use_100ns else round(__100ns2s(length), 1),
length=length,
error_line_count=error_line_count,
)
Expand All @@ -110,7 +111,9 @@ def __lab_loader(db_info: DB_Info, logger: logging.Logger) -> bool:
lab_global_error_line_count=lab_global_error_line_count,
)
)
db_info.cache["labs"] = labs

if labs != None and len(labs) > 0:
db_info.cache["labs"] = labs

return error_flag

Expand All @@ -122,7 +125,11 @@ def lab_error_check(db_info: DB_Info, logger: logging.Logger):


@__preprocess
def phoneme_count(db_info: DB_Info, logger: logging.Logger):
def phoneme_count(db_info: DB_Info, logger: logging.Logger, quiet_mode: bool = False):
if "labs" not in db_info.cache:
logger.error(L("로드된 Lab 파일을 찾을 수 없습니다."))
return

config_group = db_info.config.group
is_show_graph = db_info.config.options["graph_show"]
is_save_graph = db_info.config.options["graph_save"]
Expand Down Expand Up @@ -158,7 +165,7 @@ def add_one(dic: Dict, name: str):
else:
add_one(group_phoneme_count_dict, "error")

if is_show_graph or is_save_graph:
if not quiet_mode and (is_show_graph or is_save_graph):
logger.info(L("그래프 출력 중..."))
graph_path = db_info.config.output.graph
graph_show_dpi = db_info.config.options["graph_show_dpi"]
Expand Down Expand Up @@ -219,8 +226,13 @@ def add_one(dic: Dict, name: str):


@__preprocess
def phoneme_length(db_info: DB_Info, logger: logging.Logger):
def phoneme_length(db_info: DB_Info, logger: logging.Logger, quiet_mode: bool = False):
if "labs" not in db_info.cache:
logger.error(L("로드된 Lab 파일을 찾을 수 없습니다."))
return

config_group = db_info.config.group
use_100ns = db_info.config.options.get("use_100ns", False)
is_show_graph = db_info.config.options["graph_show"]
is_save_graph = db_info.config.options["graph_save"]
labs: Dict = db_info.cache["labs"]
Expand All @@ -236,15 +248,15 @@ def phoneme_length(db_info: DB_Info, logger: logging.Logger):

def add_one(dic: Dict, name: str, length: int):
if name in dic:
dic[name] += __100ns2s(length)
dic[name] += length if use_100ns else __100ns2s(length)
else:
dic[name] = __100ns2s(length)
dic[name] = length if use_100ns else __100ns2s(length)

def add_one_list(dic: Dict, name: str, length: int):
def add_one_list(dic: Dict[str, List[str]], name: str, length: int):
if name in dic:
dic[name].append(__100ns2s(length))
dic[name].append(length if use_100ns else __100ns2s(length))
else:
dic[name] = [__100ns2s(length)]
dic[name] = [length if use_100ns else __100ns2s(length)]

for file, lab in (labs_tqdm := tqdm(labs.items(), leave=False)):
labs_tqdm.set_description(f"[{file}] Calculating...")
Expand All @@ -264,10 +276,11 @@ def add_one_list(dic: Dict, name: str, length: int):
else:
add_one(group_phoneme_length_dict, "error", end - start)

if is_show_graph or is_save_graph:
if not quiet_mode and (is_show_graph or is_save_graph):
logger.info(L("그래프 출력 중..."))
graph_path = db_info.config.output.graph
graph_show_dpi = db_info.config.options["graph_show_dpi"]
unit_suffix = "(100ns)" if use_100ns else "(s)"

utils.matplotlib_init(db_info.config.options["graph_darkmode"])
from matplotlib import pyplot as plt
Expand All @@ -283,9 +296,9 @@ def add_one_list(dic: Dict, name: str, length: int):
single_phoneme_length_sorted_dict = dict(sorted(single_phoneme_length_dict.items(), key=lambda item: item[1], reverse=True))
keys, values = list(single_phoneme_length_sorted_dict.keys()), list(single_phoneme_length_sorted_dict.values())
b1 = plt.bar(keys, values, width=0.7)
plt.bar_label(b1, fmt="%.1fs")
plt.bar_label(b1)

plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter("%.2fs"))
plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter(f"%.2f{unit_suffix}"))

plt.title(L("Phonemes Length Statistics"))
plt.xlabel(L("Phoneme"))
Expand All @@ -312,9 +325,9 @@ def add_one_list(dic: Dict, name: str, length: int):
values.insert(1, total_length_except_silence)

b1 = plt.bar(keys, values, width=0.7)
plt.bar_label(b1, fmt="%.1fs")
plt.bar_label(b1)

plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter("%.2fs"))
plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter(f"%.2f{unit_suffix}"))

plt.title(L("Phonemes Length Statistics by Group"))
plt.xlabel(L("Phoneme Group"))
Expand Down Expand Up @@ -347,7 +360,7 @@ def add_one_list(dic: Dict, name: str, length: int):
x = [i + random.uniform(-scatter_range, scatter_range) for _ in range(len(v))]
plt.scatter(x, v)

plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter("%.2fs"))
plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter(f"%.2f{unit_suffix}"))

plt.title(L("Phonemes Length Statistics [Box plot]"))
plt.xlabel(L("Phoneme"))
Expand All @@ -361,3 +374,141 @@ def add_one_list(dic: Dict, name: str, length: int):

db_info.stats["phoneme_length"] = phoneme_length_dict
return phoneme_length_dict


@__preprocess
def phoneme_average_length(db_info: DB_Info, logger: logging.Logger, quiet_mode: bool = False):
    """Compute the average length of every phoneme and of every phoneme group.

    The average is the accumulated length from the ``phoneme_length`` statistic
    divided by the occurrence count from the ``phoneme_count`` statistic. Either
    statistic is computed on demand (with its graphs suppressed) when it is not
    already present in ``db_info.stats``. The values carry whatever unit
    ``phoneme_length`` stored (seconds, or 100ns ticks when ``use_100ns`` is set).

    Args:
        db_info: Database context; reads ``cache["labs"]``, ``config`` and
            ``stats``, and receives the result under
            ``stats["phoneme_average_length"]``.
        logger: Logger used for progress and error messages.
        quiet_mode: When true, no graph is drawn, saved or shown regardless of
            the ``graph_show`` / ``graph_save`` options.

    Returns:
        A dict ``{"group": {...}, "single": {...}}`` mapping group names and
        phoneme names to their average length, or ``None`` when no lab data has
        been loaded into the cache.
    """
    if "labs" not in db_info.cache:
        logger.error(L("로드된 Lab 파일을 찾을 수 없습니다."))
        return

    is_show_graph = db_info.config.options["graph_show"]
    is_save_graph = db_info.config.options["graph_save"]

    # The averages are derived from two other statistics; compute whichever is
    # missing in quiet mode so that their own graphs are not drawn as a side effect.
    if "phoneme_count" not in db_info.stats:
        phoneme_count(db_info, logger, True)
    if "phoneme_length" not in db_info.stats:
        phoneme_length(db_info, logger, True)

    phoneme_count_stats: Dict[str, Dict[str, int]] = db_info.stats["phoneme_count"]
    phoneme_count_stats_group = phoneme_count_stats["group"]
    phoneme_count_stats_single = phoneme_count_stats["single"]
    phoneme_length_stats: Dict[str, Dict[str, float]] = db_info.stats["phoneme_length"]
    phoneme_length_stats_group = phoneme_length_stats["group"]
    phoneme_length_stats_single = phoneme_length_stats["single"]

    # Entries that lack a length or have a zero count are skipped instead of
    # raising KeyError / ZeroDivisionError.
    single_phoneme_average_length_dict: Dict[str, float] = {
        phn: phoneme_length_stats_single[phn] / count
        for phn, count in phoneme_count_stats_single.items()
        if count and phn in phoneme_length_stats_single
    }
    group_phoneme_average_length_dict: Dict[str, float] = {
        key: phoneme_length_stats_group[key] / phoneme_count_stats_group[key]
        for key in ("consonant", "vowel", "silence", "other", "error")
        if key in phoneme_length_stats_group and phoneme_count_stats_group.get(key)
    }
    phoneme_average_length_dict = {
        "group": group_phoneme_average_length_dict,
        "single": single_phoneme_average_length_dict,
    }

    if not quiet_mode and (is_show_graph or is_save_graph):
        config_group = db_info.config.group
        logger.info(L("그래프 출력 중..."))
        graph_path = db_info.config.output.graph
        graph_show_dpi = db_info.config.options["graph_show_dpi"]
        use_100ns = db_info.config.options.get("use_100ns", False)
        unit_suffix = "(100ns)" if use_100ns else "(s)"

        utils.matplotlib_init(db_info.config.options["graph_darkmode"])
        # Imported only after utils.matplotlib_init has run, as in the other graph code.
        from matplotlib import pyplot as plt
        import matplotlib.ticker as mticker

        def draw_bar_chart(plot_name: str, keys: List[str], values: List[float], title: str, xlabel: str, figsize=None, label_rotation=None):
            """Draw one labelled bar chart, then save and/or show it per the options."""
            plt.figure(utils.get_plot_num(plot_name), figsize=figsize, dpi=graph_show_dpi)

            bars = plt.bar(keys, values, width=0.7)
            if label_rotation is None:
                plt.bar_label(bars)
            else:
                plt.bar_label(bars, rotation=label_rotation)

            plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter(f"%.2f{unit_suffix}"))

            plt.title(L(title))
            plt.xlabel(L(xlabel))
            plt.ylabel(L("Length"))
            plt.tight_layout()

            if is_save_graph:
                plt.savefig(os.path.join(graph_path, f"{plot_name}.jpg"), dpi=200)
            if is_show_graph:
                plt.show(block=False)

        #####
        # * # Average length per single phoneme
        #####
        single_phoneme_average_length_sorted_dict = dict(
            sorted(single_phoneme_average_length_dict.items(), key=lambda item: item[1], reverse=True)
        )
        draw_bar_chart(
            "phoneme_average_length_single",
            list(single_phoneme_average_length_sorted_dict.keys()),
            list(single_phoneme_average_length_sorted_dict.values()),
            "Phonemes Average Length Statistics",
            "Phoneme",
            figsize=(16, 8),
            label_rotation=45,
        )

        #####
        # * # Average length per single phoneme, silence excluded
        #####
        silence_phonemes = set(config_group.silence)
        voiced_average_length_sorted_dict = {
            phn: average
            for phn, average in single_phoneme_average_length_sorted_dict.items()
            if phn not in silence_phonemes
        }
        draw_bar_chart(
            "phoneme_average_length_single_except_silence",
            list(voiced_average_length_sorted_dict.keys()),
            list(voiced_average_length_sorted_dict.values()),
            "Phonemes Average Length Statistics (except silence)",
            "Phoneme",
            figsize=(16, 8),
            label_rotation=45,
        )

        #####
        # * # Average length per phoneme group
        #####
        group_keys = list(group_phoneme_average_length_dict.keys())
        group_values = list(group_phoneme_average_length_dict.values())
        # NOTE(review): "Total" is the sum of the per-group averages, kept as the
        # original computed it -- confirm this is the intended figure.
        total_average_length = sum(group_values)
        # The "silence" group is absent when the database has no silence phonemes.
        total_average_length_except_silence = total_average_length - group_phoneme_average_length_dict.get("silence", 0)
        draw_bar_chart(
            "phoneme_average_length_group",
            ["Total", "Total\n(except silence)", *group_keys],
            [total_average_length, total_average_length_except_silence, *group_values],
            "Phonemes Average Length Statistics by Group",
            "Phoneme Group",
        )

    db_info.stats["phoneme_average_length"] = phoneme_average_length_dict
    return phoneme_average_length_dict
17 changes: 14 additions & 3 deletions src/enunu_kor_tool/analysis4vb/functions/ust.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __ust_loader(db_info: DB_Info, logger: logging.Logger) -> bool:

ust_files = db_info.files.ust
group_config = db_info.config.group
encoding = db_info.config.options.get("encoding", "utf-8")
encoding = db_info.config.options.get("ust_encoding", "cp932")
print(encoding)
line_num_formatter = lambda ln: str(ln).rjust(4)

error_flag = False
Expand All @@ -50,7 +51,7 @@ def __ust_loader(db_info: DB_Info, logger: logging.Logger) -> bool:

for file in (file_tqdm := tqdm(ust_files, leave=False)):
file_tqdm.set_description(f"Processing... [{file}]")
logger.info(L("[{filepath}] 파일 로드 중...", filepath=os.path.relpath(file)))
logger.info(L("[{filepath}] 파일 로드 중...", filepath=file))

ust = up.ust.load(file, encoding=encoding)

Expand Down Expand Up @@ -93,7 +94,9 @@ def __ust_loader(db_info: DB_Info, logger: logging.Logger) -> bool:
global_notes_voiced_length_sum=round(global_notes_voiced_length_sum, 3),
)
)
db_info.cache["usts"] = usts

if usts != None and len(usts) > 0:
db_info.cache["usts"] = usts

return error_flag

Expand All @@ -106,6 +109,10 @@ def ust_error_check(db_info: DB_Info, logger: logging.Logger):

@__preprocess
def pitch_note_count(db_info: DB_Info, logger: logging.Logger):
if "usts" not in db_info.cache:
logger.error(L("로드된 Ust, Ustx 파일을 찾을 수 없습니다."))
return

config_group = db_info.config.group
is_show_graph = db_info.config.options["graph_show"]
is_save_graph = db_info.config.options["graph_save"]
Expand Down Expand Up @@ -166,6 +173,10 @@ def pitch_note_count(db_info: DB_Info, logger: logging.Logger):

@__preprocess
def pitch_note_length(db_info: DB_Info, logger: logging.Logger):
if "usts" not in db_info.cache:
logger.error(L("로드된 Ust, Ustx 파일을 찾을 수 없습니다."))
return

config_group = db_info.config.group
is_show_graph = db_info.config.options["graph_show"]
is_save_graph = db_info.config.options["graph_save"]
Expand Down
Loading

0 comments on commit 57e4594

Please sign in to comment.