Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add limiter of openai #23

Merged
merged 1 commit into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/apps/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package apps

import (
"context"
"fmt"

"github.com/spf13/cobra"
Expand All @@ -37,7 +38,7 @@ var KeywordsCmd = &cobra.Command{
}

func keywords(content string) error {
a, err := friday.Fri.Keywords(content)
a, err := friday.Fri.Keywords(context.TODO(), content)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/apps/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package apps

import (
"context"
"fmt"
"strings"

Expand All @@ -38,7 +39,7 @@ var QuestionCmd = &cobra.Command{
}

func run(question string) error {
a, err := friday.Fri.Question(question)
a, err := friday.Fri.Question(context.TODO(), question)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/apps/summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package apps

import (
"context"
"fmt"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -46,7 +47,7 @@ func init() {
}

func summary(ps string) error {
a, err := friday.Fri.SummaryFromOriginFile(ps, fridaysummary.SummaryType(summaryType))
a, err := friday.Fri.SummaryFromOriginFile(context.TODO(), ps, fridaysummary.SummaryType(summaryType))
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/apps/wechat.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package apps

import (
"context"
"fmt"
"strings"

Expand All @@ -38,7 +39,7 @@ var WeChatCmd = &cobra.Command{
}

func chat(ps string) error {
a, err := friday.Fri.ChatConclusionFromFile(ps)
a, err := friday.Fri.ChatConclusionFromFile(context.TODO(), ps)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ type Config struct {
EmbeddingDim int `json:"embedding_dim,omitempty"` // embedding dimension, default is 1536

// LLM
LLMType LLMType `json:"llm_type"`
LLMUrl string `json:"llm_url,omitempty"` // only needed for glm-6b
LLMRateLimit int `json:"llm_rate_limit,omitempty"` // only needed for openai, rate_limit, in seconds, default is 60
LLMType LLMType `json:"llm_type"`
LLMUrl string `json:"llm_url,omitempty"` // only needed for glm-6b
LLMQueryPerMinute int `json:"llm_query_per_minute,omitempty"` // only needed for openai, qpm, default is 3
LLMBurst int `json:"llm_burst,omitempty"` // only needed for openai, burst, default is 5

// text spliter
SpliterChunkSize int `json:"spliter_chunk_size,omitempty"` // chunk of files splited to store, default is 4000
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ require (
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/time v0.4.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.4.0 h1:Z81tqI5ddIoXDPvVQ7/7CC9TnLM7ubaFG2qXYd5BbYY=
golang.org/x/time v0.4.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
Expand Down
4 changes: 2 additions & 2 deletions pkg/build/withvector/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
if conf.OpenAIBaseUrl == "" {
conf.OpenAIBaseUrl = "https://api.openai.com"
}
llmClient = openaiv1.NewOpenAIV1(conf.OpenAIBaseUrl, conf.OpenAIKey, conf.LLMRateLimit)
llmClient = openaiv1.NewOpenAIV1(conf.OpenAIBaseUrl, conf.OpenAIKey, conf.LLMQueryPerMinute, conf.LLMBurst)
}
if conf.LLMType == config.LLMGLM6B {
llmClient = glm_6b.NewGLM(conf.LLMUrl)
Expand All @@ -51,7 +51,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
if conf.OpenAIBaseUrl == "" {
conf.OpenAIBaseUrl = "https://api.openai.com"
}
embeddingModel = openaiembedding.NewOpenAIEmbedding(conf.OpenAIBaseUrl, conf.OpenAIKey, conf.LLMRateLimit)
embeddingModel = openaiembedding.NewOpenAIEmbedding(conf.OpenAIBaseUrl, conf.OpenAIKey, conf.LLMQueryPerMinute, conf.LLMBurst)
}
if conf.EmbeddingType == config.EmbeddingHuggingFace {
embeddingModel = huggingfaceembedding.NewHuggingFace(conf.EmbeddingUrl, conf.EmbeddingModel)
Expand Down
10 changes: 6 additions & 4 deletions pkg/embedding/openai/v1/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package v1

import (
"context"

"github.com/basenana/friday/pkg/embedding"
"github.com/basenana/friday/pkg/llm/client/openai/v1"
)
Expand All @@ -27,14 +29,14 @@ type OpenAIEmbedding struct {

var _ embedding.Embedding = &OpenAIEmbedding{}

func NewOpenAIEmbedding(baseUrl, key string, rateLimit int) embedding.Embedding {
func NewOpenAIEmbedding(baseUrl, key string, qpm, burst int) embedding.Embedding {
return &OpenAIEmbedding{
OpenAIV1: v1.NewOpenAIV1(baseUrl, key, rateLimit),
OpenAIV1: v1.NewOpenAIV1(baseUrl, key, qpm, burst),
}
}

func (o *OpenAIEmbedding) VectorQuery(doc string) ([]float32, map[string]interface{}, error) {
res, err := o.Embedding(doc)
res, err := o.Embedding(context.TODO(), doc)
if err != nil {
return nil, nil, err
}
Expand All @@ -51,7 +53,7 @@ func (o *OpenAIEmbedding) VectorDocs(docs []string) ([][]float32, []map[string]i
metadata := make([]map[string]interface{}, len(docs))

for i, doc := range docs {
r, err := o.Embedding(doc)
r, err := o.Embedding(context.TODO(), doc)
if err != nil {
return nil, nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/friday/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
package friday

import (
"context"
"strings"

"github.com/basenana/friday/pkg/llm/prompts"
)

func (f *Friday) Keywords(content string) (keywords []string, err error) {
func (f *Friday) Keywords(ctx context.Context, content string) (keywords []string, err error) {
prompt := prompts.NewKeywordsPrompt()

answers, err := f.LLM.Chat(prompt, map[string]string{"context": content})
answers, err := f.LLM.Chat(ctx, prompt, map[string]string{"context": content})
if err != nil {
return []string{}, err
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/friday/keywords_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package friday

import (
"context"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

Expand All @@ -37,7 +39,7 @@ var _ = Describe("TestKeywords", func() {

Context("keywords", func() {
It("keywords should be succeed", func() {
keywords, err := loFriday.Keywords("test")
keywords, err := loFriday.Keywords(context.TODO(), "test")
Expect(err).Should(BeNil())
Expect(keywords).Should(Equal([]string{"a", "b", "c"}))
})
Expand All @@ -48,10 +50,10 @@ type FakeKeyWordsLLM struct{}

var _ llm.LLM = &FakeKeyWordsLLM{}

func (f FakeKeyWordsLLM) Completion(prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
func (f FakeKeyWordsLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
return []string{}, nil
}

func (f FakeKeyWordsLLM) Chat(prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
func (f FakeKeyWordsLLM) Chat(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
return []string{"a, b, c"}, nil
}
5 changes: 3 additions & 2 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@
package friday

import (
"context"
"fmt"
"strings"

"github.com/basenana/friday/pkg/llm/prompts"
)

func (f *Friday) Question(q string) (string, error) {
func (f *Friday) Question(ctx context.Context, q string) (string, error) {
prompt := prompts.NewQuestionPrompt()
c, err := f.searchDocs(q)
if err != nil {
return "", err
}
if f.LLM != nil {
ans, err := f.LLM.Chat(prompt, map[string]string{
ans, err := f.LLM.Chat(ctx, prompt, map[string]string{
"context": c,
"question": q,
})
Expand Down
8 changes: 5 additions & 3 deletions pkg/friday/question_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package friday

import (
"context"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

Expand Down Expand Up @@ -45,7 +47,7 @@ var _ = Describe("TestQuestion", func() {

Context("question", func() {
It("question should be succeed", func() {
ans, err := loFriday.Question("I am a question")
ans, err := loFriday.Question(context.TODO(), "I am a question")
Expect(err).Should(BeNil())
Expect(ans).Should(Equal("I am an answer"))
})
Expand Down Expand Up @@ -93,10 +95,10 @@ type FakeQuestionLLM struct{}

var _ llm.LLM = &FakeQuestionLLM{}

func (f FakeQuestionLLM) Completion(prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
func (f FakeQuestionLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
return []string{}, nil
}

func (f FakeQuestionLLM) Chat(prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
func (f FakeQuestionLLM) Chat(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, error) {
return []string{"I am an answer"}, nil
}
20 changes: 11 additions & 9 deletions pkg/friday/summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
package friday

import (
"context"

"github.com/basenana/friday/pkg/friday/summary"
"github.com/basenana/friday/pkg/models"
"github.com/basenana/friday/pkg/utils/files"
)

func (f *Friday) Summary(elements []models.Element, summaryType summary.SummaryType) (map[string]string, error) {
func (f *Friday) Summary(ctx context.Context, elements []models.Element, summaryType summary.SummaryType) (map[string]string, error) {
result := make(map[string]string)
s := summary.NewSummary(f.LLM, 0)
s := summary.NewSummary(f.LLM, f.LimitToken)

docs := make(map[string][]string)
for _, element := range elements {
Expand All @@ -35,7 +37,7 @@ func (f *Friday) Summary(elements []models.Element, summaryType summary.SummaryT
}
}
for source, doc := range docs {
summaryOfFile, err := s.Summary(doc, summaryType)
summaryOfFile, err := s.Summary(ctx, doc, summaryType)
if err != nil {
return nil, err
}
Expand All @@ -45,12 +47,12 @@ func (f *Friday) Summary(elements []models.Element, summaryType summary.SummaryT
return result, nil
}

func (f *Friday) SummaryFromFile(file models.File, summaryType summary.SummaryType) (map[string]string, error) {
s := summary.NewSummary(f.LLM, 0)
func (f *Friday) SummaryFromFile(ctx context.Context, file models.File, summaryType summary.SummaryType) (map[string]string, error) {
s := summary.NewSummary(f.LLM, f.LimitToken)
// split doc
docs := f.Spliter.Split(file.Content)
// summary
summaryOfFile, err := s.Summary(docs, summaryType)
summaryOfFile, err := s.Summary(ctx, docs, summaryType)
if err != nil {
return nil, err
}
Expand All @@ -59,8 +61,8 @@ func (f *Friday) SummaryFromFile(file models.File, summaryType summary.SummaryTy
}, err
}

func (f *Friday) SummaryFromOriginFile(ps string, summaryType summary.SummaryType) (map[string]string, error) {
s := summary.NewSummary(f.LLM, 0)
func (f *Friday) SummaryFromOriginFile(ctx context.Context, ps string, summaryType summary.SummaryType) (map[string]string, error) {
s := summary.NewSummary(f.LLM, f.LimitToken)
fs, err := files.ReadFiles(ps)
if err != nil {
return nil, err
Expand All @@ -70,7 +72,7 @@ func (f *Friday) SummaryFromOriginFile(ps string, summaryType summary.SummaryTyp
for name, file := range fs {
// split doc
subDocs := f.Spliter.Split(file)
summaryOfFile, err := s.Summary(subDocs, summaryType)
summaryOfFile, err := s.Summary(ctx, subDocs, summaryType)
if err != nil {
return nil, err
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/friday/summary/map-reduce.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
package summary

import (
"context"
"fmt"
"strings"

"github.com/basenana/friday/pkg/llm/prompts"
"github.com/basenana/friday/pkg/utils/files"
)

func (s *Summary) MapReduce(docs []string) (summary string, err error) {
func (s *Summary) MapReduce(ctx context.Context, docs []string) (summary string, err error) {
// map
splitedSummaries, err := s.mapSummaries(docs)
splitedSummaries, err := s.mapSummaries(ctx, docs)
if err != nil {
return "", err
}
Expand All @@ -38,7 +39,7 @@ func (s *Summary) MapReduce(docs []string) (summary string, err error) {
}

// reduce
return s.reduce(splitedSummaries)
return s.reduce(ctx, splitedSummaries)
}

func (s *Summary) splitDocs(p prompts.PromptTemplate, docs []string) ([][]string, error) {
Expand Down Expand Up @@ -77,7 +78,7 @@ func (s *Summary) getLength(p prompts.PromptTemplate, docs []string) (length int
return length, nil
}

func (s *Summary) mapSummaries(docs []string) ([]string, error) {
func (s *Summary) mapSummaries(ctx context.Context, docs []string) ([]string, error) {
newDocs, err := s.splitDocs(s.summaryPrompt, docs)
if err != nil {
return nil, err
Expand All @@ -86,7 +87,7 @@ func (s *Summary) mapSummaries(docs []string) ([]string, error) {

splitedSummaries := []string{}
for _, splitedDocs := range newDocs {
d, err := s.Stuff(splitedDocs)
d, err := s.Stuff(ctx, splitedDocs)
if err != nil {
return nil, err
}
Expand All @@ -95,15 +96,15 @@ func (s *Summary) mapSummaries(docs []string) ([]string, error) {
return splitedSummaries, nil
}

func (s *Summary) reduce(summaries []string) (summary string, err error) {
func (s *Summary) reduce(ctx context.Context, summaries []string) (summary string, err error) {
newSummaries, err := s.splitDocs(s.combinePrompt, summaries)
if err != nil {
return "", err
}
combinedSummaries := []string{}
for _, subSummaries := range newSummaries {
subSummary := strings.Join(subSummaries, "\n")
res, err := s.llm.Chat(s.combinePrompt, map[string]string{"context": subSummary})
res, err := s.llm.Chat(ctx, s.combinePrompt, map[string]string{"context": subSummary})
if err != nil {
return "", err
}
Expand All @@ -113,5 +114,5 @@ func (s *Summary) reduce(summaries []string) (summary string, err error) {
if len(combinedSummaries) == 1 {
return combinedSummaries[0], nil
}
return s.reduce(combinedSummaries)
return s.reduce(ctx, combinedSummaries)
}
Loading
Loading