Skip to content

Commit

Permalink
Merge pull request #21 from basenana/fix/map_reduce
Browse files Browse the repository at this point in the history
mapreduce: fix length
  • Loading branch information
zwwhdls committed Nov 19, 2023
2 parents 15ee8ec + 346f3b6 commit 2611340
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 9 deletions.
8 changes: 7 additions & 1 deletion pkg/friday/summary/map-reduce.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"strings"

"github.com/basenana/friday/pkg/llm/prompts"
"github.com/basenana/friday/pkg/utils/files"
)

func (s *Summary) MapReduce(docs []string) (summary string, err error) {
Expand Down Expand Up @@ -68,7 +69,12 @@ func (s *Summary) getLength(p prompts.PromptTemplate, docs []string) (length int
if err != nil {
return 0, err
}
return len(res), nil

ress := strings.Split(res, "\n")
for _, r := range ress {
length += files.Length(r)
}
return length, nil
}

func (s *Summary) mapSummaries(docs []string) ([]string, error) {
Expand Down
73 changes: 73 additions & 0 deletions pkg/friday/summary/map-reduce_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
Copyright 2023 Friday Author.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package summary

import (
"testing"

"github.com/basenana/friday/pkg/llm/prompts"
"github.com/basenana/friday/pkg/utils/logger"
)

func TestSummary_getLength(t *testing.T) {
type fields struct {
limitToken int
}
type args struct {
docs []string
}
tests := []struct {
name string
fields fields
args args
wantLength int
wantErr bool
}{
{
name: "test1",
fields: fields{
limitToken: 10,
},
args: args{
docs: []string{
"I am a doc",
"You are a doc too",
},
},
wantLength: 23,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Summary{
log: logger.NewLogger("test"),
summaryPrompt: prompts.NewSummaryPrompt(),
combinePrompt: prompts.NewCombinePrompt(),
limitToken: tt.fields.limitToken,
}
gotLength, err := s.getLength(s.summaryPrompt, tt.args.docs)
if (err != nil) != tt.wantErr {
t.Errorf("getLength() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotLength != tt.wantLength {
t.Errorf("getLength() gotLength = %v, want %v", gotLength, tt.wantLength)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/llm/prompts/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type KeywordsPrompt struct {
Context string
}

const KeyWordsTemplate = `Extract keywords from the following, separated by comma and reply in zh-CN Language.:
const KeyWordsTemplate = `Extract keywords from the following, separated by comma and reply in zh-CN Language:
"{{ .Context }}"
Expand Down
9 changes: 2 additions & 7 deletions pkg/spliter/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"strings"

"github.com/basenana/friday/pkg/models"
"github.com/basenana/friday/pkg/utils/files"
"github.com/basenana/friday/pkg/utils/logger"
)

Expand Down Expand Up @@ -49,13 +50,7 @@ func NewTextSpliter(chunkSize int, chunkOverlap int, separator string) Spliter {
}

func (t *TextSpliter) length(d string) int {
// todo: it should be more accurate
// https://platform.openai.com/docs/guides/text-generation/managing-tokens
pured := strings.TrimSpace(d)
if pured == "" {
return 0
}
return len(strings.Split(strings.TrimSpace(pured), " "))
return files.Length(d)
}

func (t *TextSpliter) Split(text string) []string {
Expand Down
29 changes: 29 additions & 0 deletions pkg/utils/files/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
Copyright 2023 Friday Author.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package files

import "strings"

func Length(doc string) int {
// todo: it should be more accurate
// https://platform.openai.com/docs/guides/text-generation/managing-tokens
pured := strings.TrimSpace(doc)
if pured == "" {
return 0
}
return len(strings.Split(strings.TrimSpace(pured), " "))
}

0 comments on commit 2611340

Please sign in to comment.