Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update qa.csv col #567

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions apiserver/docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1173,13 +1173,28 @@ const docTemplate = `{
"type": "string",
"example": "旷工最小计算单位为 0.5 天。"
},
"file_path": {
"description": "file fullpath",
"content": {
"description": "related content in the source file",
"type": "string",
"example": "旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。"
},
"file_name": {
"description": "source file name, only file name, not full path",
"type": "string",
"example": "员工考勤管理制度-2023.pdf"
},
"page_number": {
"description": "page number in the source file",
"type": "integer",
"example": 1
},
"qa_file_path": {
"description": "the qa file fullpath",
"type": "string",
"example": "dataset/dataset-playground/v1/qa.csv"
},
"line_number": {
"description": "line number in the file",
"qa_line_number": {
"description": "line number in the qa file",
"type": "integer",
"example": 7
},
Expand Down
23 changes: 19 additions & 4 deletions apiserver/docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1162,13 +1162,28 @@
"type": "string",
"example": "旷工最小计算单位为 0.5 天。"
},
"file_path": {
"description": "file fullpath",
"content": {
"description": "related content in the source file",
"type": "string",
"example": "旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。"
},
"file_name": {
"description": "source file name, only file name, not full path",
"type": "string",
"example": "员工考勤管理制度-2023.pdf"
},
"page_number": {
"description": "page number in the source file",
"type": "integer",
"example": 1
},
"qa_file_path": {
"description": "the qa file fullpath",
"type": "string",
"example": "dataset/dataset-playground/v1/qa.csv"
},
"line_number": {
"description": "line number in the file",
"qa_line_number": {
"description": "line number in the qa file",
"type": "integer",
"example": 7
},
Expand Down
20 changes: 16 additions & 4 deletions apiserver/docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,24 @@ definitions:
description: Answer row
example: 旷工最小计算单位为 0.5 天。
type: string
file_path:
description: file fullpath
content:
description: related content in the source file
example: 旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。
type: string
file_name:
description: source file name, only file name, not full path
example: 员工考勤管理制度-2023.pdf
type: string
page_number:
description: page number in the source file
example: 1
type: integer
qa_file_path:
description: the qa file fullpath
example: dataset/dataset-playground/v1/qa.csv
type: string
line_number:
description: line number in the file
qa_line_number:
description: line number in the qa file
example: 7
type: integer
question:
Expand Down
2 changes: 1 addition & 1 deletion controllers/base/knowledgebase_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
loader = documentloaders.NewText(dataReader)
case ".csv":
if v == arcadiav1alpha1.ObjectTypeQA {
loader = pkgdocumentloaders.NewQACSV(dataReader, fileName, "q", "a")
loader = pkgdocumentloaders.NewQACSV(dataReader, fileName)
documents, err = loader.Load(ctx)
if err != nil {
return err
Expand Down
74 changes: 60 additions & 14 deletions pkg/appruntime/retriever/knowledgebaseretriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1"
"github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/appruntime/base"
"github.com/kubeagi/arcadia/pkg/documentloaders"
"github.com/kubeagi/arcadia/pkg/langchainwrap"
pkgvectorstore "github.com/kubeagi/arcadia/pkg/vectorstore"
)
Expand All @@ -47,10 +48,16 @@ type Reference struct {
Answer string `json:"answer" example:"旷工最小计算单位为 0.5 天。"`
// vector search score
Score float32 `json:"score" example:"0.34"`
// file fullpath
FilePath string `json:"file_path" example:"dataset/dataset-playground/v1/qa.csv"`
// line number in the file
LineNumber int `json:"line_number" example:"7"`
// the qa file fullpath
QAFilePath string `json:"qa_file_path" example:"dataset/dataset-playground/v1/qa.csv"`
// line number in the qa file
QALineNumber int `json:"qa_line_number" example:"7"`
// source file name, only file name, not full path
FileName string `json:"file_name" example:"员工考勤管理制度-2023.pdf"`
// page number in the source file
PageNumber int `json:"page_number" example:"1"`
// related content in the source file
Content string `json:"content" example:"旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。"`
}

func (reference Reference) String() string {
Expand Down Expand Up @@ -177,23 +184,62 @@ func (c *KnowledgeBaseStuffDocuments) joinDocuments(ctx context.Context, docs []
logger.V(3).Info(fmt.Sprintf("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %#v\n", k, key, v))
}
}
answer, _ := doc.Metadata["a"].([]byte)
// chroma will get []byte, pgvector will get string...
answer, ok := doc.Metadata[documentloaders.AnswerCol].(string)
if !ok {
if a, ok := doc.Metadata[documentloaders.AnswerCol].([]byte); ok {
answer = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"")
}
}

text += doc.PageContent
if len(answer) != 0 {
text = text + "\na: " + strings.TrimPrefix(strings.TrimSuffix(string(answer), "\""), "\"")
text = text + "\na: " + answer
}
if k != docLen-1 {
text += c.Separator
}
filepath, _ := doc.Metadata["fileName"].([]byte)
lineNumber, _ := doc.Metadata["lineNumber"].([]byte)
line, _ := strconv.Atoi(string(lineNumber))
qafilepath, ok := doc.Metadata[documentloaders.QAFileName].(string)
if !ok {
if a, ok := doc.Metadata[documentloaders.QAFileName].([]byte); ok {
qafilepath = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"")
}
}
lineNumber, ok := doc.Metadata[documentloaders.LineNumber].(string)
if !ok {
if a, ok := doc.Metadata[documentloaders.LineNumber].([]byte); ok {
lineNumber = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"")
}
}
line, _ := strconv.Atoi(lineNumber)
filename, ok := doc.Metadata[documentloaders.FileNameCol].(string)
if !ok {
if a, ok := doc.Metadata[documentloaders.FileNameCol].([]byte); ok {
filename = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"")
}
}
pageNumber, ok := doc.Metadata[documentloaders.PageNumberCol].(string)
if !ok {
if a, ok := doc.Metadata[documentloaders.PageNumberCol].([]byte); ok {
pageNumber = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"")
}
}
page, _ := strconv.Atoi(pageNumber)
content, ok := doc.Metadata[documentloaders.ChunkContentCol].(string)
if !ok {
if a, ok := doc.Metadata[documentloaders.ChunkContentCol].([]byte); ok {
content = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"")
}
}
c.References = append(c.References, Reference{
Question: doc.PageContent,
Answer: strings.TrimPrefix(strings.TrimSuffix(string(answer), "\""), "\""),
Score: doc.Score,
FilePath: strings.TrimPrefix(strings.TrimSuffix(string(filepath), "\""), "\""),
LineNumber: line,
Question: doc.PageContent,
Answer: answer,
Score: doc.Score,
QAFilePath: qafilepath,
QALineNumber: line,
FileName: filename,
PageNumber: page,
Content: content,
})
}
logger.V(3).Info(fmt.Sprintf("KnowledgeBaseRetriever: finally get related text: %s\n", text))
Expand Down
100 changes: 73 additions & 27 deletions pkg/documentloaders/qa_csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,89 @@ import (
"errors"
"fmt"
"io"
"strconv"
"strings"

"github.com/tmc/langchaingo/documentloaders"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/textsplitter"
)

// Default header names of the columns a QA CSV file is expected to contain.
// Apart from QuestionCol, each name doubles as the metadata key under which
// the loader stores that column's value.
const (
	// QuestionCol is the default header of the question column.
	QuestionCol = "q"
	// AnswerCol is the default header of the answer column.
	AnswerCol = "a"
	// FileNameCol is the default header of the source file name column.
	FileNameCol = "file_name"
	// PageNumberCol is the default header of the source page number column.
	PageNumberCol = "page_number"
	// ChunkContentCol is the default header of the source chunk content column.
	ChunkContentCol = "chunk_content"
)

// Metadata keys that the loader fills in itself rather than reading from a
// CSV column.
const (
	// LineNumber is the metadata key holding the row number within the QA file.
	LineNumber = "line_number"
	// QAFileName is the metadata key holding the name of the QA file.
	QAFileName = "qafile_name"
)

// QACSV represents a QA CSV document loader.
type QACSV struct {
r io.Reader
fileName string
questionColumn string
answerColumn string
r io.Reader
fileName string
questionColumn string
answerColumn string
fileNameColumn string
pageNumberColumn string
chunkContentColumn string
}

var _ documentloaders.Loader = QACSV{}

// NewQACSV creates a new qa csv loader with an io.Reader and optional column names for filtering.
func NewQACSV(r io.Reader, fileName string, questionColumn string, answerColumn string) QACSV {
if questionColumn == "" {
questionColumn = "q"
// Option configures a QACSV loader. Options are passed to NewQACSV, which
// applies them to the loader it is constructing.
type Option func(p *QACSV)

// WithQuestionColumn returns an Option that sets the header name of the
// question column to s.
func WithQuestionColumn(s string) Option {
	return func(loader *QACSV) { loader.questionColumn = s }
}
// WithAnswerColumn returns an Option that sets the header name of the
// answer column to s.
func WithAnswerColumn(s string) Option {
	return func(loader *QACSV) { loader.answerColumn = s }
}
// WithFileNameColumn returns an Option that sets the header name of the
// source file name column to s.
func WithFileNameColumn(s string) Option {
	return func(loader *QACSV) { loader.fileNameColumn = s }
}
func WithPageNumberColumn(s string) Option {
return func(p *QACSV) {
p.pageNumberColumn = s
}
if answerColumn == "" {
answerColumn = "a"
}
func WithChunkContentColumn(s string) Option {
return func(p *QACSV) {
p.chunkContentColumn = s
}
return QACSV{
r: r,
fileName: fileName,
questionColumn: questionColumn,
answerColumn: answerColumn,
}

// NewQACSV builds a QA CSV loader that reads from r and keeps fileName on the
// loader. Column names start out as QuestionCol, AnswerCol, FileNameCol,
// PageNumberCol and ChunkContentCol; the supplied options are then applied in
// the order given and may override any of them.
func NewQACSV(r io.Reader, fileName string, opts ...Option) QACSV {
	loader := &QACSV{
		r:                  r,
		fileName:           fileName,
		questionColumn:     QuestionCol,
		answerColumn:       AnswerCol,
		fileNameColumn:     FileNameCol,
		pageNumberColumn:   PageNumberCol,
		chunkContentColumn: ChunkContentCol,
	}
	for i := range opts {
		opts[i](loader)
	}
	return *loader
}

// Load reads CSV rows from the io.Reader and returns one document per data row.
func (c QACSV) Load(_ context.Context) ([]schema.Document, error) {
var header []string
var docs []schema.Document
var rown int
cols := []string{c.questionColumn, c.answerColumn, c.fileNameColumn, c.pageNumberColumn, c.chunkContentColumn}

rd := csv.NewReader(c.r)
for {
Expand All @@ -58,22 +103,23 @@ func (c QACSV) Load(_ context.Context) ([]schema.Document, error) {
header = append(header, row...)
continue
}

doc := schema.Document{}
doc.Metadata = make(map[string]any, len(cols)-1)
for i, value := range row {
if c.questionColumn != "" && header[i] != c.questionColumn && header[i] != c.answerColumn {
continue
}
value = strings.TrimSpace(value)
if header[i] == c.questionColumn {
switch header[i] {
case c.questionColumn:
doc.PageContent = fmt.Sprintf("%s: %s", header[i], value)
}
if header[i] == c.answerColumn {
doc.Metadata = map[string]any{
c.answerColumn: value,
"fileName": c.fileName,
"lineNumber": rown,
}
case c.answerColumn:
doc.Metadata[AnswerCol] = value
doc.Metadata[QAFileName] = c.fileName
doc.Metadata[LineNumber] = strconv.Itoa(rown)
case c.pageNumberColumn:
doc.Metadata[PageNumberCol] = value
case c.fileNameColumn:
doc.Metadata[FileNameCol] = value
case c.chunkContentColumn:
doc.Metadata[ChunkContentCol] = value
}
}
rown++
Expand Down
27 changes: 17 additions & 10 deletions pkg/documentloaders/qa_csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,36 @@ func TestCSVLoader(t *testing.T) {
file, err := os.Open(fileName)
assert.NoError(t, err)

loader := NewQACSV(file, fileName, "q", "a")
loader := NewQACSV(file, fileName)

docs, err := loader.Load(context.Background())
require.NoError(t, err)
require.Len(t, docs, 25)
require.Len(t, docs, 9)

expected1PageContent := "q: 什么是员工考勤管理制度?"
expectedFileName := "员工考勤管理制度-2023.pdf"
expected1PageContent := "q: 员工在病假期间,是否有额外的假期?"
assert.Equal(t, docs[0].PageContent, expected1PageContent)

expected1Metadata := map[string]any{
"a": "该制度旨在严格工作纪律、提高工作效率,规范公司考勤管理,为公司考勤管理提供明确依据。",
"fileName": fileName,
"lineNumber": 0,
AnswerCol: "无法确定,题目未提及。",
QAFileName: fileName,
LineNumber: "0",
FileNameCol: expectedFileName,
ChunkContentCol: "3员工请假时间小于等于2天,由直接上级、部门负责人审批,人力资源部备案;员工请假时间大于等于3天,依次由直接上级、部门负责人、公司管理层审批,人力资源部备案。二.事假1、申请事假须至少提前1天在钉钉上发起请假申请,经直属领导逐级审批通过后,抄送人力资源部备案。事假最小计算单位为0.5天,不足0.5天以0.5天计算,以此类推。2、如遇特殊情况未能事前申请,须于当日10:00前电话或其他有效方式告知直属上级和人力资源部,且在事后1日内在钉钉补充完成事假申请审批手续。3、事假理由不充分或有碍工作进度,公司可不予准假。一年内累计不能超过20天。事假扣除事假相应天数工资,期间无其他奖金、福利和补助。三.病假1、因病不能正常上班,需病假者。病假申请(急诊、门诊)审批要求和流程同事假审批,2天以上(含2天)病假需提供医院有效的病假证明。2、一年享有3天带薪病假。若3<正常病假天数≤60,日工资按合同工资50%计算;正常病假累计天数>",
PageNumberCol: "3",
}
assert.Equal(t, docs[0].Metadata, expected1Metadata)

expected2PageContent := "q: 该制度适用于哪些员工?"
expected2PageContent := "q: 公司的考勤管理制度适用于哪些人员?"
assert.Equal(t, docs[1].PageContent, expected2PageContent)

expected2Metadata := map[string]any{
"a": "适用于公司全体正式员工及实习生。",
"fileName": fileName,
"lineNumber": 1,
AnswerCol: "公司全体正式员工及实习生。",
QAFileName: fileName,
LineNumber: "1",
FileNameCol: expectedFileName,
ChunkContentCol: "1第一章总则一、目的为了严格工作纪律、提高工作效率,规范公司考勤管理,为公司考勤管理提供明确依据,现根据国家及当地地区相关法律法规,特制定本制度。二、使用范围1、本制度适用公司全体正式员工及实习生。2、员工应严格遵守工作律及考勤规章制度。各部门负责人在权限范围内有审批部门员工考勤记录的权利和严肃考勤纪律的义务,并以身作则,规范执行。3、人力资源部负责考勤信息的记录、汇总,监督考勤制度的执行。第二章工时制度及考勤方式1、考勤时间:1)公司执行五天弹性工作制,上班时间为9:00-9:30,下班时间为18:00-18:30,中午12:00-13:00为午休时间,不计入工作时间;每天工作时间不少于8小时。2)公司考虑交通通勤情况,每天上班给予10分钟延迟;9:40后为迟到打卡,每月最多迟到3次(不晚于10:00),超出则视为旷工;晚于10:00打卡且无正当理由,视为旷工半天;3)因工作原因下班晚走2小时,第二天打卡时间不晚于上午10:00,考勤打卡数据将作为员工日常管理和薪资核算的重要依据。",
PageNumberCol: "1",
}
assert.Equal(t, docs[1].Metadata, expected2Metadata)
}
Loading