From 147c91211e3a259a42e08a2d5967bf220785558f Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Tue, 7 Nov 2023 20:58:47 +0800 Subject: [PATCH] fix keywords Signed-off-by: zwwhdls --- pkg/friday/ingest_test.go | 2 +- pkg/friday/keywords.go | 9 ++++++++- pkg/friday/keywords_test.go | 2 ++ pkg/friday/question.go | 1 + pkg/friday/question_test.go | 2 +- pkg/friday/summary.go | 1 + pkg/friday/summary_test.go | 4 ++++ 7 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pkg/friday/ingest_test.go b/pkg/friday/ingest_test.go index 5c1fe69..89f77cc 100644 --- a/pkg/friday/ingest_test.go +++ b/pkg/friday/ingest_test.go @@ -34,7 +34,7 @@ var _ = Describe("TestIngest", func() { BeforeEach(func() { loFriday.Vector = FakeStore{} - loFriday.Log = logger.NewLogger("test") + loFriday.Log = logger.NewLogger("test-ingest") loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") loFriday.Embedding = FakeEmbedding{} }) diff --git a/pkg/friday/keywords.go b/pkg/friday/keywords.go index 03f65ff..7761ff0 100644 --- a/pkg/friday/keywords.go +++ b/pkg/friday/keywords.go @@ -31,5 +31,12 @@ func (f *Friday) Keywords(content string) (keywords []string, err error) { } answer := answers[0] keywords = strings.Split(answer, " ") - return keywords, nil + result := []string{} + for _, keyword := range keywords { + if len(keyword) != 0 { + result = append(result, keyword) + } + } + f.Log.Debugf("Keywords result: %v", result) + return result, nil } diff --git a/pkg/friday/keywords_test.go b/pkg/friday/keywords_test.go index af84569..2b08ee4 100644 --- a/pkg/friday/keywords_test.go +++ b/pkg/friday/keywords_test.go @@ -22,6 +22,7 @@ import ( "github.com/basenana/friday/pkg/llm" "github.com/basenana/friday/pkg/llm/prompts" + "github.com/basenana/friday/pkg/utils/logger" ) var _ = Describe("TestKeywords", func() { @@ -31,6 +32,7 @@ var _ = Describe("TestKeywords", func() { BeforeEach(func() { loFriday.LLM = FakeKeyWordsLLM{} + loFriday.Log = logger.NewLogger("test-keywords") }) Context("keywords", func() { diff --git a/pkg/friday/question.go b/pkg/friday/question.go index 7ae4ce0..badc962 100644 --- a/pkg/friday/question.go +++ b/pkg/friday/question.go @@ -37,6 +37,7 @@ func (f *Friday) Question(q string) (string, error) { if err != nil { return "", fmt.Errorf("llm completion error: %w", err) } + f.Log.Debugf("Question result: %s", c) return ans[0], nil } return c, nil diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go index 67c0efc..1e63fcd 100644 --- a/pkg/friday/question_test.go +++ b/pkg/friday/question_test.go @@ -36,7 +36,7 @@ var _ = Describe("TestQuestion", func() { BeforeEach(func() { loFriday.Vector = FakeStore{} - loFriday.Log = logger.NewLogger("test") + loFriday.Log = logger.NewLogger("test-question") loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") loFriday.Embedding = FakeQuestionEmbedding{} loFriday.LLM = FakeQuestionLLM{} diff --git a/pkg/friday/summary.go b/pkg/friday/summary.go index 630a3be..69be604 100644 --- a/pkg/friday/summary.go +++ b/pkg/friday/summary.go @@ -41,6 +41,7 @@ func (f *Friday) Summary(elements []models.Element, summaryType summary.SummaryT } result[source] = summaryOfFile } + f.Log.Debugf("Summary result: %s", result) return result, nil } diff --git a/pkg/friday/summary_test.go b/pkg/friday/summary_test.go index 797b2d4..115dd95 100644 --- a/pkg/friday/summary_test.go +++ b/pkg/friday/summary_test.go @@ -25,6 +25,7 @@ import ( "github.com/basenana/friday/pkg/llm/prompts" "github.com/basenana/friday/pkg/models" "github.com/basenana/friday/pkg/spliter" + "github.com/basenana/friday/pkg/utils/logger" ) var _ = Describe("TestStuffSummary", func() { @@ -36,6 +37,7 @@ var _ = Describe("TestStuffSummary", func() { ) BeforeEach(func() { + loFriday.Log = logger.NewLogger("test-stuffsummary") loFriday.LLM = FakeSummaryLLM{} loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") elements = []models.Element{{ @@ -80,6 +82,7 @@ var _ = Describe("TestMapReduceSummary", func() { ) BeforeEach(func() { + loFriday.Log = logger.NewLogger("test-mapreduce-summary") loFriday.LLM = FakeSummaryLLM{} loFriday.LimitToken = 4 loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") @@ -125,6 +128,7 @@ var _ = Describe("TestRefineSummary", func() { ) BeforeEach(func() { + loFriday.Log = logger.NewLogger("test-refine-summary") loFriday.LLM = FakeSummaryLLM{} loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") elements = []models.Element{{