Merge pull request #980 from wangxinbiao/main
fix: vectorize in batches.
bjwswang committed Apr 9, 2024
2 parents 2533a90 + 6709c29 commit aa43dde
Showing 5 changed files with 199 additions and 116 deletions.
17 changes: 14 additions & 3 deletions pypi/data-processing/src/data_store_process/minio_store_process.py
@@ -407,7 +407,7 @@ async def text_manipulate(
return {"status": 400, "message": str(ex), "data": traceback.format_exc()}


def text_manipulate_retry(req_json, pool):
async def text_manipulate_retry(req_json, pool):
task_id = req_json.get("id")
creator = req_json.get("creator")
log_id = ulid.ulid()
@@ -470,7 +470,7 @@ def text_manipulate_retry(req_json, pool):
]
)
)
result = _text_manipulate_retry_for_document(
result = await _text_manipulate_retry_for_document(
document=document,
task_info=task_info_dict,
log_id=log_id,
@@ -937,7 +937,7 @@ def _insert_log_info(id, task_id, execute_type, creator, pool):
return {"status": 400, "message": str(ex), "data": traceback.format_exc()}


def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
async def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
file_name = document.get("file_name")
task_id = task_info.get("id")
document_id = document.get("id")
@@ -1025,6 +1025,16 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creat
task_id=task_id,
create_user=creator,
)
elif file_extension == "web":
# Handle .web files
result = await web_handle.web_manipulate(
file_name=file_name,
document_id=document_id,
support_type=support_type,
conn_pool=pool,
task_id=task_id,
create_user=creator,
)

# Delete the downloaded local file
_remove_local_file(file_name)
@@ -1042,6 +1052,7 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creat
file_name=file_name,
all_document_for_process=document_chunk_dict.get("data"),
support_type=support_type,
progress=int(document.get("progress")),
conn_pool=pool,
create_user=creator,
)
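
With this change, text_manipulate_retry and _text_manipulate_retry_for_document in minio_store_process.py become coroutines, so callers have to await them or schedule them on an event loop. A minimal sketch of what a caller might look like, assuming an asyncio-based caller, an already-constructed connection pool, and illustrative payload fields (the keys mirror the id and creator fields read at the top of text_manipulate_retry):

import asyncio

async def retry_task(req_json, pool):
    # text_manipulate_retry is now a coroutine and must be awaited; internally it
    # awaits _text_manipulate_retry_for_document for each failed document.
    return await text_manipulate_retry(req_json, pool)

# Illustrative only; "pool" stands for the PostgreSQL connection pool used elsewhere.
# asyncio.run(retry_task({"id": "some-task-id", "creator": "admin"}, pool))
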
@@ -160,6 +160,35 @@ def update_document_progress(req_json, pool):
return res


def update_document_status_and_progress(req_json, pool):
"""Update the status and progress with id"""
now = date_time_utils.now_str()
program = "文件处理完成-修改"

params = {
"id": req_json["id"],
"status": req_json["status"],
"end_time": now,
"progress": req_json["progress"],
"update_datetime": now,
"update_program": program,
}

sql = """
update public.data_process_task_document set
status = %(status)s,
end_time = %(end_time)s,
progress = %(progress)s,
update_datetime = %(update_datetime)s,
update_program = %(update_program)s
where
id = %(id)s
""".strip()

res = postgresql_pool_client.execute_update(pool, sql, params)
return res


def list_file_by_task_id(req_json, pool):
"""info with id"""
params = {"task_id": req_json["task_id"]}
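
The new update_document_status_and_progress helper follows the same request-dict-plus-pool convention as the other functions in this module. A minimal usage sketch, assuming it is invoked from the file-handling code the way common_handle.py below does, with placeholder values for the document id and pool:

# Sketch only: mark a document as finished once all of its chunks are processed.
document_update_item = {
    "id": document_id,    # data_process_task_document row id (placeholder)
    "status": "success",  # written together with end_time by the UPDATE above
    "progress": 100,      # percentage stored in the progress column
}
update_document_status_and_progress(document_update_item, pool=conn_pool)
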
170 changes: 153 additions & 17 deletions pypi/data-processing/src/file_handle/common_handle.py
@@ -37,7 +37,12 @@


def text_manipulate(
all_document_for_process, file_name, support_type, conn_pool, create_user
all_document_for_process,
file_name,
support_type,
conn_pool,
create_user,
progress=0
):
"""Manipulate the text content.
@@ -63,7 +68,7 @@ def text_manipulate(
conn_pool=conn_pool,
)

text_process_success_num = 0
text_process_success_num = progress
for document in all_document_for_process:
document_chunk_id = document.get("id")
# Clean the data such as removing invisible characters.
@@ -116,11 +121,6 @@
if qa_response.get("status") != 200:
return qa_response

# File processed successfully; update the file status in data_process_task_document
_updata_document_status_and_end_time(
id=document_id, status="success", conn_pool=conn_pool
)

if support_type_map.get("qa_split"):
# Whether QA split was selected
qa_list_dict = support_type_map.get("qa_split")
@@ -196,6 +196,13 @@ def text_manipulate(
file_name=file_name_csv, phase_value="final", data=qa_data_dict
)

_update_document_status_and_progress(
id=document_id,
status="success",
progress=100,
conn_pool=conn_pool
)

logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text")
return {
"status": 200,
@@ -225,13 +232,25 @@ def text_manipulate(
file_name=file_name_csv, phase_value="final", data=chunk_data_dict
)

_update_document_status_and_progress(
id=document_id,
status="success",
progress=100,
conn_pool=conn_pool
)

logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text")
return {
"status": 200,
"message": "",
"data": "",
}

# File processed successfully; update the file status in data_process_task_document
_updata_document_status_and_end_time(
id=document_id, status="success", conn_pool=conn_pool
)

return {"status": 200, "message": "", "data": ""}
except Exception as ex:
logger.error(
@@ -914,6 +933,7 @@ def _qa_split(
):
qa_list_dict = support_type_map.get("qa_split")
llm_config = qa_list_dict.get("llm_config")
remove_duplicate_config = qa_list_dict.get("remove_duplicate_config")

# Update the chunk status to started
_update_document_chunk_status_and_start_time(
@@ -937,6 +957,7 @@
id=document_id, status="fail", conn_pool=conn_pool
)
else:
qa_list = []
# Store the QA data in the table
qa_data = qa_response.get("data")
for _, item in enumerate(qa_data):
Expand All @@ -955,6 +976,34 @@ def _qa_split(
qa_insert_item, pool=conn_pool
)

qa_list.append(qa_insert_item)

# Whether deduplication is needed
if remove_duplicate_config:
for qa in qa_list:
embedding_response = _embedding_qa(
qa_list=[qa],
remove_duplicate_config=remove_duplicate_config,
conn_pool=conn_pool
)

if embedding_response.get("status") != 200:
# Processing failed
# Update the status in data_process_task_document_chunk
_updata_document_chunk_status_and_end_time(
id=document_chunk_id,
update_user=create_user,
status="fail",
conn_pool=conn_pool,
)

# Update the file status in data_process_task_document
_updata_document_status_and_end_time(
id=document_id, status="fail", conn_pool=conn_pool
)

return embedding_response

# Update the status in data_process_task_document_chunk
_updata_document_chunk_status_and_end_time(
id=document_chunk_id,
@@ -965,6 +1014,9 @@

# Update the file-processing progress
progress = int(text_process_success_num / document_chunk_size * 100)
if text_process_success_num == document_chunk_size:
progress = 99

_updata_document_progress(
id=document_id,
progress=progress,
@@ -994,7 +1046,7 @@ def _generate_qa_list(content, llm_config):

# Generate the QA list.
qa_list = []
if llm_spec_info.get("data").get("provider").get("worker"):
if llm_config.get("provider") == "worker":
# get base url from the configmap
base_url = model_cr.get_worker_base_url_k8s_configmap(
name=config.k8s_default_config, namespace=config.k8s_pod_namespace
@@ -1190,6 +1242,26 @@ def _updata_document_progress(id, progress, update_user, conn_pool):
return {"status": 1000, "message": str(ex), "data": traceback.format_exc()}


def _update_document_status_and_progress(id, status, progress, conn_pool):
try:
document_update_item = {"id": id, "status": status, "progress": progress}
data_process_document_db_operate.update_document_status_and_progress(
document_update_item, pool=conn_pool
)

return {"status": 200, "message": "", "data": ""}
except Exception as ex:
logger.error(
"".join(
[
f"{log_tag_const.COMMON_HANDLE} update document status ",
f"\n{traceback.format_exc()}",
]
)
)
return {"status": 1000, "message": str(ex), "data": traceback.format_exc()}


def _update_document_chunk_status_and_start_time(id, update_user, conn_pool):
try:
now = date_time_utils.now_str()
@@ -1292,8 +1364,8 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
provider = remove_duplicate_config.get("embedding_provider")
similarity = float(remove_duplicate_config.get("similarity"))

# Model information from the llms CR
llm_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)
# Model information from the embedding CR
embedding_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)

if provider == "worker":
# get base url for configmap
@@ -1319,11 +1391,11 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity)
return remove_duplicate_loader.remove_duplicate_qa_data(qa_list, similarity)
else:
endpoint = llm_spec_info.get("data").get("provider").get("endpoint")
endpoint = embedding_spec_info.get("data").get("provider").get("endpoint")
base_url = endpoint.get("url")
llm_type = llm_spec_info.get("data").get("type")
embedding_type = embedding_spec_info.get("data").get("type")

logger.debug(
"".join(
Expand All @@ -1332,19 +1404,83 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
f"name: {name}\n",
f"namespace: {namespace}\n",
f"model: {model}\n",
f"llm_type: {llm_type}\n",
f"embedding_type: {embedding_type}\n",
]
)
)

if embedding_type == "openai":
qa_embeddings = OpenAIEmbeddings(
api_key="fake",
base_url=base_url,
model=model,
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.remove_duplicate_qa_data(qa_list, similarity)
else:
return {"status": 1000, "message": f"暂时不支持{embedding_type}类型的向量化模型模型", "data": ""}


def _embedding_qa(qa_list, remove_duplicate_config, conn_pool):
name = remove_duplicate_config.get("embedding_name")
namespace = remove_duplicate_config.get("embedding_namespace")
model = remove_duplicate_config.get("embedding_model")
provider = remove_duplicate_config.get("embedding_provider")

# Model information from the embeddings CR
embedding_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)

if provider == "worker":
# get base url from the configmap
base_url = model_cr.get_worker_base_url_k8s_configmap(
name=config.k8s_default_config, namespace=config.k8s_pod_namespace
)
logger.debug(
"".join(
[
f"worker embedding \n",
f"name: {name}\n",
f"namespace: {namespace}\n",
f"model: {model}\n",
f"base_url: {base_url}\n",
]
)
)

qa_embeddings = OpenAIEmbeddings(
api_key="fake",
base_url=base_url,
model=model,
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.embedding_qa_data(qa_list)
else:
endpoint = embedding_spec_info.get("data").get("provider").get("endpoint")
base_url = endpoint.get("url")
embedding_type = embedding_spec_info.get("data").get("type")

logger.debug(
"".join(
[
f"3rd_party embedding \n",
f"name: {name}\n",
f"namespace: {namespace}\n",
f"model: {model}\n",
f"embedding_type: {embedding_type}\n",
]
)
)

if llm_type == "openai":
if embedding_type == "openai":
qa_embeddings = OpenAIEmbeddings(
api_key="fake",
base_url=base_url,
model=model,
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity)
return remove_duplicate_loader.embedding_qa_data(qa_list)
else:
return {"status": 1000, "message": f"暂时不支持{llm_type}类型的向量化模型模型", "data": ""}
return {"status": 1000, "message": f"暂时不支持{embedding_type}类型的向量化模型模型", "data": ""}