Skip to content

Commit

Permalink
Merge pull request #40 from basenana/feature/api
Browse files Browse the repository at this point in the history
update: unified api
  • Loading branch information
zwwhdls committed Mar 19, 2024
2 parents 542015f + 4999197 commit a34e28c
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 85 deletions.
11 changes: 6 additions & 5 deletions cmd/apps/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ var KeywordsCmd = &cobra.Command{
}

func keywords(content string) error {
a, usage, err := friday.Fri.Keywords(context.TODO(), content)
if err != nil {
return err
res := friday.KeywordsState{}
f := friday.Fri.WithContext(context.TODO()).Content(content).Keywords(&res)
if f.Error != nil {
return f.Error
}
fmt.Println("Answer: ")
fmt.Println(a)
fmt.Printf("Usage: %v", usage)
fmt.Println(res.Keywords)
fmt.Printf("Usage: %v", res.Tokens)
return nil
}
11 changes: 6 additions & 5 deletions cmd/apps/summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ func init() {
}

func summary(ps string) error {
a, usage, err := friday.Fri.SummaryFromOriginFile(context.TODO(), ps, fridaysummary.SummaryType(summaryType))
if err != nil {
return err
res := friday.SummaryState{}
f := friday.Fri.WithContext(context.TODO()).OriginFile(&ps).OfType(fridaysummary.SummaryType(summaryType)).Summary(&res)
if f.Error != nil {
return f.Error
}
fmt.Println("Answer: ")
fmt.Println(a)
fmt.Printf("Usage: %v", usage)
fmt.Println(res.Summary)
fmt.Printf("Usage: %v", res.Tokens)
return nil
}
19 changes: 18 additions & 1 deletion pkg/friday/friday.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"

"github.com/basenana/friday/pkg/embedding"
"github.com/basenana/friday/pkg/friday/summary"
"github.com/basenana/friday/pkg/llm"
"github.com/basenana/friday/pkg/models"
"github.com/basenana/friday/pkg/spliter"
Expand Down Expand Up @@ -66,11 +67,17 @@ type Statement struct {
query *models.DocQuery
info string

// for ingest
// for ingest or summary
file *models.File // a whole file providing models.File
elementFile *string // a whole file given an element-style origin file
originFile *string // a whole file given an origin file
elements []models.Element

// for keywords
content string

// for summary
summaryType summary.SummaryType
}

type ChatState struct {
Expand All @@ -79,6 +86,16 @@ type ChatState struct {
Tokens map[string]int
}

type KeywordsState struct {
Keywords []string
Tokens map[string]int
}

type SummaryState struct {
Summary map[string]string
Tokens map[string]int
}

type IngestState struct {
Tokens map[string]int
}
Expand Down
16 changes: 11 additions & 5 deletions pkg/friday/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@
package friday

import (
"context"
"strings"

"github.com/basenana/friday/pkg/llm/prompts"
)

func (f *Friday) Keywords(ctx context.Context, content string) ([]string, map[string]int, error) {
func (f *Friday) Content(content string) *Friday {
f.statement.content = content
return f
}

func (f *Friday) Keywords(res *KeywordsState) *Friday {
prompt := prompts.NewKeywordsPrompt(f.Prompts[keywordsPromptKey])

answers, usage, err := f.LLM.Completion(ctx, prompt, map[string]string{"context": content})
answers, usage, err := f.LLM.Completion(f.statement.context, prompt, map[string]string{"context": f.statement.content})
if err != nil {
return []string{}, nil, err
return f
}
answer := answers[0]
keywords := strings.Split(answer, ",")
Expand All @@ -39,5 +43,7 @@ func (f *Friday) Keywords(ctx context.Context, content string) ([]string, map[st
}
}
f.Log.Debugf("Keywords result: %v", result)
return result, usage, nil
res.Keywords = result
res.Tokens = usage
return f
}
7 changes: 4 additions & 3 deletions pkg/friday/keywords_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ var _ = Describe("TestKeywords", func() {

Context("keywords", func() {
It("keywords should be succeed", func() {
keywords, _, err := loFriday.Keywords(context.TODO(), "test")
Expect(err).Should(BeNil())
Expect(keywords).Should(Equal([]string{"a", "b", "c"}))
res := KeywordsState{}
f := loFriday.WithContext(context.TODO()).Content("test").Keywords(&res)
Expect(f.Error).Should(BeNil())
Expect(res.Keywords).Should(Equal([]string{"a", "b", "c"}))
})
})
})
Expand Down
110 changes: 58 additions & 52 deletions pkg/friday/summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,73 +17,79 @@
package friday

import (
"context"

"github.com/basenana/friday/pkg/friday/summary"
"github.com/basenana/friday/pkg/models"
"github.com/basenana/friday/pkg/utils/files"
)

func (f *Friday) Summary(ctx context.Context, elements []models.Element, summaryType summary.SummaryType) (map[string]string, map[string]int, error) {
result := make(map[string]string)
s := summary.NewSummary(f.Log, f.LLM, f.LimitToken, f.Prompts)
func (f *Friday) OfType(summaryType summary.SummaryType) *Friday {
f.statement.summaryType = summaryType
return f
}

docs := make(map[string][]string)
for _, element := range elements {
if _, ok := docs[element.Name]; !ok {
docs[element.Name] = []string{element.Content}
} else {
docs[element.Name] = append(docs[element.Name], element.Content)
}
func (f *Friday) Summary(res *SummaryState) *Friday {
if f.statement.summaryType == "" {
f.statement.summaryType = summary.Stuff
}
totalUsage := make(map[string]int)
for source, doc := range docs {
summaryOfFile, usage, err := s.Summary(ctx, doc, summaryType)
res.Summary = map[string]string{}
res.Tokens = map[string]int{}
// init
s := summary.NewSummary(f.Log, f.LLM, f.LimitToken, f.Prompts)

if f.statement.file != nil {
// split doc
docs := f.Spliter.Split(f.statement.file.Content)
// summary
summaryOfFile, usage, err := s.Summary(f.statement.context, docs, f.statement.summaryType)
if err != nil {
return nil, nil, err
}
result[source] = summaryOfFile
for k, v := range usage {
totalUsage[k] = totalUsage[k] + v
f.Error = err
return f
}
res.Summary = map[string]string{f.statement.file.Name: summaryOfFile}
res.Tokens = usage
return f
}
f.Log.Debugf("Summary result: %s", result)
return result, totalUsage, nil
}

func (f *Friday) SummaryFromFile(ctx context.Context, file models.File, summaryType summary.SummaryType) (map[string]string, map[string]int, error) {
s := summary.NewSummary(f.Log, f.LLM, f.LimitToken, f.Prompts)
// split doc
docs := f.Spliter.Split(file.Content)
// summary
summaryOfFile, usage, err := s.Summary(ctx, docs, summaryType)
if err != nil {
return nil, nil, err
}
return map[string]string{file.Name: summaryOfFile}, usage, err
}
if f.statement.originFile != nil {
fs, err := files.ReadFiles(*f.statement.originFile)
if err != nil {
f.Error = err
return f
}

func (f *Friday) SummaryFromOriginFile(ctx context.Context, ps string, summaryType summary.SummaryType) (map[string]string, map[string]int, error) {
s := summary.NewSummary(f.Log, f.LLM, f.LimitToken, f.Prompts)
fs, err := files.ReadFiles(ps)
if err != nil {
return nil, nil, err
for name, file := range fs {
// split doc
subDocs := f.Spliter.Split(file)
summaryOfFile, usage, err := s.Summary(f.statement.context, subDocs, f.statement.summaryType)
if err != nil {
f.Error = err
return f
}
res.Summary[name] = summaryOfFile
res.Tokens = mergeTokens(usage, res.Tokens)
}
return f
}

result := make(map[string]string)
totalUsage := make(map[string]int)
for name, file := range fs {
// split doc
subDocs := f.Spliter.Split(file)
summaryOfFile, usage, err := s.Summary(ctx, subDocs, summaryType)
if err != nil {
return nil, nil, err
if len(f.statement.elements) != 0 {
docs := make(map[string][]string)
for _, element := range f.statement.elements {
if _, ok := docs[element.Name]; !ok {
docs[element.Name] = []string{element.Content}
} else {
docs[element.Name] = append(docs[element.Name], element.Content)
}
}
result[name] = summaryOfFile
for k, v := range usage {
totalUsage[k] = totalUsage[k] + v
for source, doc := range docs {
summaryOfFile, usage, err := s.Summary(f.statement.context, doc, f.statement.summaryType)
if err != nil {
f.Error = err
return f
}
res.Summary[source] = summaryOfFile
res.Tokens = mergeTokens(usage, res.Tokens)
}
return f
}

return result, totalUsage, nil
return f
}
34 changes: 20 additions & 14 deletions pkg/friday/summary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,18 @@ var _ = Describe("TestStuffSummary", func() {

Context("summary", func() {
It("summary should be succeed", func() {
summary, _, err := loFriday.Summary(context.TODO(), elements, summaryType)
Expect(err).Should(BeNil())
Expect(summary).Should(Equal(map[string]string{
res := SummaryState{}
f := loFriday.WithContext(context.TODO()).Element(elements).OfType(summaryType).Summary(&res)
Expect(f.Error).Should(BeNil())
Expect(res.Summary).Should(Equal(map[string]string{
"test-title": "a b c",
}))
})
It("SummaryFromFile should be succeed", func() {
summary, _, err := loFriday.SummaryFromFile(context.TODO(), file, summaryType)
Expect(err).Should(BeNil())
Expect(summary).Should(Equal(map[string]string{
res := SummaryState{}
f := loFriday.WithContext(context.TODO()).File(&file).OfType(summaryType).Summary(&res)
Expect(f.Error).Should(BeNil())
Expect(res.Summary).Should(Equal(map[string]string{
"test-file": "a b c",
}))
})
Expand Down Expand Up @@ -97,16 +99,18 @@ var _ = Describe("TestMapReduceSummary", func() {

Context("summary", func() {
It("summary should be succeed", func() {
summary, _, err := loFriday.Summary(context.TODO(), elements, summaryType)
Expect(err).Should(BeNil())
Expect(summary).Should(Equal(map[string]string{
res := SummaryState{}
f := loFriday.WithContext(context.TODO()).Element(elements).OfType(summaryType).Summary(&res)
Expect(f.Error).Should(BeNil())
Expect(res.Summary).Should(Equal(map[string]string{
"test-title": "a b c",
}))
})
It("SummaryFromFile should be succeed", func() {
summary, _, err := loFriday.SummaryFromFile(context.TODO(), file, summaryType)
Expect(err).Should(BeNil())
Expect(summary).Should(Equal(map[string]string{
res := SummaryState{}
f := loFriday.WithContext(context.TODO()).File(&file).OfType(summaryType).Summary(&res)
Expect(f.Error).Should(BeNil())
Expect(res.Summary).Should(Equal(map[string]string{
"test-file": "a b c",
}))
})
Expand Down Expand Up @@ -138,11 +142,13 @@ var _ = Describe("TestRefineSummary", func() {

Context("summary", func() {
It("summary should be succeed", func() {
_, _, _ = loFriday.Summary(context.TODO(), elements, summaryType)
res := SummaryState{}
_ = loFriday.WithContext(context.TODO()).Element(elements).OfType(summaryType).Summary(&res)
// todo
})
It("SummaryFromFile should be succeed", func() {
_, _, _ = loFriday.SummaryFromFile(context.TODO(), file, summaryType)
res := SummaryState{}
_ = loFriday.WithContext(context.TODO()).File(&file).OfType(summaryType).Summary(&res)
// todo
})
})
Expand Down

0 comments on commit a34e28c

Please sign in to comment.