From a3bf21e337fc0f91c638707eb7d10e7931e49a60 Mon Sep 17 00:00:00 2001
From: zwwhdls
Date: Wed, 20 Mar 2024 22:42:20 +0800
Subject: [PATCH] update: support gemini chat

Signed-off-by: zwwhdls
---
 pkg/friday/keywords_test.go      | 12 +++++
 pkg/friday/question.go           | 82 +++++++++++++++++++++-----------
 pkg/friday/question_test.go      | 12 +++++
 pkg/friday/summary_test.go       | 12 +++++
 pkg/llm/client/gemini/chat.go    | 59 ++++++++++++++++-------
 pkg/llm/client/glm-6b/client.go  | 12 +++++
 pkg/llm/client/openai/v1/chat.go | 55 ++++++++++++++-------
 pkg/llm/llm.go                   |  4 ++
 8 files changed, 186 insertions(+), 62 deletions(-)

diff --git a/pkg/friday/keywords_test.go b/pkg/friday/keywords_test.go
index 4a71ff3..1dcbe0f 100644
--- a/pkg/friday/keywords_test.go
+++ b/pkg/friday/keywords_test.go
@@ -51,6 +51,18 @@ type FakeKeyWordsLLM struct{}
 
 var _ llm.LLM = &FakeKeyWordsLLM{}
 
+func (f FakeKeyWordsLLM) GetUserModel() string {
+	return "user"
+}
+
+func (f FakeKeyWordsLLM) GetSystemModel() string {
+	return "system"
+}
+
+func (f FakeKeyWordsLLM) GetAssistantModel() string {
+	return "assistant"
+}
+
 func (f FakeKeyWordsLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
 	return []string{"a, b, c"}, nil, nil
 }
diff --git a/pkg/friday/question.go b/pkg/friday/question.go
index 070aff0..77799da 100644
--- a/pkg/friday/question.go
+++ b/pkg/friday/question.go
@@ -25,6 +25,8 @@ import (
 	"github.com/basenana/friday/pkg/models"
 )
 
+const remainHistoryNum = 3 // must be odd, so the kept tail starts and ends with a user turn
+
 func (f *Friday) History(history []map[string]string) *Friday {
 	f.statement.history = history
 	return f
 }
@@ -54,7 +56,7 @@ func (f *Friday) Chat(res *ChatState) *Friday {
 		// search for docs
 		questions := ""
 		for _, d := range f.statement.history {
-			if d["role"] == "user" {
+			if d["role"] == f.LLM.GetUserModel() {
 				questions = fmt.Sprintf("%s\n%s", questions, d["content"])
 			}
 		}
@@ -68,8 +70,12 @@
 	}
 	f.statement.history = append([]map[string]string{
 		{
-			"role":    "system",
-			"content": fmt.Sprintf("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: %s", f.statement.info),
+			"role":    f.LLM.GetSystemModel(),
+			"content": fmt.Sprintf("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: %s\n", f.statement.info),
+		},
+		{
+			"role":    f.LLM.GetAssistantModel(),
+			"content": "", // empty assistant turn keeps roles alternating for providers like Gemini
 		},
 	}, f.statement.history...)
 
@@ -82,46 +88,67 @@ func (f *Friday) chat(res *ChatState) *Friday {
 		return f
 	}
 	var (
-		tokens    = map[string]int{}
 		dialogues = []map[string]string{}
 	)
 
 	// If the dialogue exceeds two rounds, summarize it first for use as context.
-	if len(f.statement.history) >= 5 {
-		sumDialogue := make([]map[string]string, 0, len(f.statement.history))
+	if len(f.statement.history) >= remainHistoryNum {
+		sumDialogue := make([]map[string]string, len(f.statement.history))
 		copy(sumDialogue, f.statement.history)
-		sumDialogue = append(sumDialogue, map[string]string{
-			"role":    "system",
+		sumDialogue[len(sumDialogue)-1] = map[string]string{ // swap the latest turn for the summarize instruction
+			"role":    f.LLM.GetSystemModel(),
 			"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
-		})
+		}
 		var (
 			sumBuf = make(chan map[string]string)
 			sum    = make(map[string]string)
-			usage  = make(map[string]int)
 			err    error
 		)
-		defer close(sumBuf)
 		go func() {
-			usage, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
+			defer close(sumBuf)
+			_, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
+			if err != nil {
+				f.Error = err
+				return
+			}
 		}()
-		if err != nil {
-			f.Error = err
-			return f
-		}
-		tokens = mergeTokens(usage, tokens)
 		select {
 		case <-f.statement.context.Done():
 			return f
 		case sum = <-sumBuf:
 			// add context prompt for dialogue
-			dialogues = append(dialogues, []map[string]string{
-				f.statement.history[0],
-				{
-					"role":    "system",
-					"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
-				},
-			}...)
-			dialogues = append(dialogues, f.statement.history[len(f.statement.history)-5:len(f.statement.history)]...)
+			if f.statement.query != nil {
+				// ingested info is present; combine it with the summary.
+				dialogues = []map[string]string{
+					{
+						"role": f.LLM.GetSystemModel(),
+						"content": fmt.Sprintf(
+							"%s\n%s",
+							f.statement.history[0]["content"],
+							fmt.Sprintf("这是历史聊天总结作为前情提要:%s\n", sum["content"]),
+						),
+					},
+					{
+						"role":    f.LLM.GetAssistantModel(),
+						"content": "",
+					},
+				}
+			} else {
+				dialogues = []map[string]string{
+					{
+						"role":    f.LLM.GetSystemModel(),
+						"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
+					},
+					{
+						"role":    f.LLM.GetAssistantModel(),
+						"content": "",
+					},
+				}
+			}
+			dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:]...)
+		}
+		if f.Error != nil {
+			return f
 		}
 	} else {
 		dialogues = make([]map[string]string, len(f.statement.history))
 		copy(dialogues, f.statement.history)
 	}
 
 	// go for llm
-	usage, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
+	_, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
 	if err != nil {
 		f.Error = err
 		return f
 	}
-	tokens = mergeTokens(tokens, usage)
-
-	res.Tokens = tokens
 	return f
 }
diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go
index fd4149a..5ed0811 100644
--- a/pkg/friday/question_test.go
+++ b/pkg/friday/question_test.go
@@ -236,6 +236,18 @@ type FakeQuestionLLM struct{}
 
 var _ llm.LLM = &FakeQuestionLLM{}
 
+func (f FakeQuestionLLM) GetUserModel() string {
+	return "user"
+}
+
+func (f FakeQuestionLLM) GetSystemModel() string {
+	return "system"
+}
+
+func (f FakeQuestionLLM) GetAssistantModel() string {
+	return "assistant"
+}
+
 func (f FakeQuestionLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
 	return []string{"I am an answer"}, nil, nil
 }
diff --git a/pkg/friday/summary_test.go b/pkg/friday/summary_test.go
index 42ddd63..aba9a57 100644
--- a/pkg/friday/summary_test.go
+++ b/pkg/friday/summary_test.go
@@ -158,6 +158,18 @@ type FakeSummaryLLM struct{}
 
 var _ llm.LLM = &FakeSummaryLLM{}
 
+func (f FakeSummaryLLM) GetUserModel() string {
+	return "user"
+}
+
+func (f FakeSummaryLLM) GetSystemModel() string {
+	return "system"
+}
+
+func (f FakeSummaryLLM) GetAssistantModel() string {
+	return "assistant"
+}
+
 func (f FakeSummaryLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
 	return []string{"a b c"}, nil, nil
 }
diff --git a/pkg/llm/client/gemini/chat.go b/pkg/llm/client/gemini/chat.go
index 5661b5c..2afb166 100644
--- a/pkg/llm/client/gemini/chat.go
+++ b/pkg/llm/client/gemini/chat.go
@@ -20,10 +20,27 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"strings"
 )
 
+func (g *Gemini) GetUserModel() string {
+	return "user"
+}
+
+func (g *Gemini) GetSystemModel() string {
+	return "user" // Gemini's API has no system role; system prompts are sent as user turns
+}
+
+func (g *Gemini) GetAssistantModel() string {
+	return "model"
+}
+
 func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
-	path := fmt.Sprintf("v1beta/models/%s:streamGenerateContent", *g.conf.Model)
+	// default to the non-streaming endpoint
+	path := fmt.Sprintf("v1beta/models/%s:generateContent", *g.conf.Model)
+	if stream {
+		path = fmt.Sprintf("v1beta/models/%s:streamGenerateContent", *g.conf.Model)
+	}
 
 	contents := make([]map[string]any, 0)
 	for _, hs := range history {
@@ -39,25 +56,35 @@ func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]str
 	go func() {
 		defer close(buf)
 		err = g.request(ctx, stream, path, "POST", map[string]any{"contents": contents}, buf)
+		if err != nil {
+			return
+		}
 	}()
-	if err != nil {
-		return
-	}
 
 	for line := range buf {
-		var res ChatResult
-		err = json.Unmarshal(line, &res)
-		if err != nil {
-			return nil, err
-		}
-		if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
-			g.log.Errorf("gemini response: %s ", string(line))
-			return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
-		}
 		ans := make(map[string]string)
-		for _, c := range res.Candidates {
-			for _, t := range c.Content.Parts {
-				ans[c.Content.Role] = t.Text
+		l := strings.TrimSpace(string(line))
+		if stream {
+			if !strings.HasPrefix(l, "\"text\"") {
+				continue
+			}
+			// fragments look like "text": "xxx", so cut the 9-byte prefix and the trailing two characters
+			ans["content"] = l[9 : len(l)-2]
+		} else {
+			var res ChatResult
+			err = json.Unmarshal(line, &res)
+			if err != nil {
+				return nil, err
+			}
+			if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
+				g.log.Errorf("gemini response: %s ", string(line))
+				return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
+			}
+			for _, c := range res.Candidates {
+				for _, t := range c.Content.Parts {
+					ans["role"] = c.Content.Role
+					ans["content"] = t.Text
+				}
 			}
 		}
 		select {
diff --git a/pkg/llm/client/glm-6b/client.go b/pkg/llm/client/glm-6b/client.go
index fc40acb..1dbb6c1 100644
--- a/pkg/llm/client/glm-6b/client.go
+++ b/pkg/llm/client/glm-6b/client.go
@@ -76,6 +76,18 @@ type CompletionResult struct {
 	Time    string `json:"time"`
 }
 
+func (o *GLM) GetUserModel() string {
+	return "user"
+}
+
+func (o *GLM) GetSystemModel() string {
+	return "system"
+}
+
+func (o *GLM) GetAssistantModel() string {
+	return "assistant"
+}
+
 func (o *GLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
 	path := ""
 
diff --git a/pkg/llm/client/openai/v1/chat.go b/pkg/llm/client/openai/v1/chat.go
index c166017..47fc8f0 100644
--- a/pkg/llm/client/openai/v1/chat.go
+++ b/pkg/llm/client/openai/v1/chat.go
@@ -52,11 +52,19 @@ type ChatStreamChoice struct {
 	FinishReason string `json:"finish_reason,omitempty"`
 }
 
-func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (map[string]int, error) {
-	return o.chat(ctx, stream, history, answers)
+func (o *OpenAIV1) GetUserModel() string {
+	return "user"
 }
 
-func (o *OpenAIV1) chat(ctx context.Context, stream bool, history []map[string]string, resp chan<- map[string]string) (tokens map[string]int, err error) {
+func (o *OpenAIV1) GetSystemModel() string {
+	return "system"
+}
+
+func (o *OpenAIV1) GetAssistantModel() string {
+	return "assistant"
+}
+
+func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]string, resp chan<- map[string]string) (tokens map[string]int, err error) {
 	path := "v1/chat/completions"
 
 	data := map[string]interface{}{
@@ -73,30 +81,43 @@
 
 	buf := make(chan []byte)
 	go func() {
+		defer close(buf)
 		err = o.request(ctx, stream, path, "POST", data, buf)
-		close(buf)
+		if err != nil {
+			return
+		}
 	}()
-	if err != nil {
-		return
-	}
 
 	for line := range buf {
-		var res ChatStreamResult
-		if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
-			continue
-		}
-		l := string(line)[6:]
-		err = json.Unmarshal([]byte(l), &res)
-		if err != nil {
-			err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
-			return
-		}
+		var delta map[string]string
+		if stream {
+			var res ChatStreamResult
+			if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
+				continue
+			}
+			// strip the leading "data: " prefix
+			l := string(line)[6:]
+			err = json.Unmarshal([]byte(l), &res)
+			if err != nil {
+				err = fmt.Errorf("cannot unmarshal msg: %s, err: %v", line, err)
+				return
+			}
+			delta = res.Choices[0].Delta
+		} else {
+			var res ChatResult
+			err = json.Unmarshal(line, &res)
+			if err != nil {
+				err = fmt.Errorf("cannot unmarshal msg: %s, err: %v", line, err)
+				return
+			}
+			delta = res.Choices[0].Message
+		}
 		select {
 		case <-ctx.Done():
 			err = fmt.Errorf("context timeout in openai chat")
 			return
-		case resp <- res.Choices[0].Delta:
+		case resp <- delta:
 			continue
 		}
 	}
diff --git a/pkg/llm/llm.go b/pkg/llm/llm.go
index 2463e43..651bd36 100644
--- a/pkg/llm/llm.go
+++ b/pkg/llm/llm.go
@@ -23,6 +23,10 @@ import (
 )
 
 type LLM interface {
+	GetUserModel() string      // role name for user messages
+	GetSystemModel() string    // role name for system prompts
+	GetAssistantModel() string // role name for assistant/model replies
+
 	// Completion chat with llm just once
 	Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) (answers []string, tokens map[string]int, err error)
 	/*
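
For illustration only, not part of this patch: a minimal sketch of how calling code can rely on the new role getters so the same dialogue-building logic works for both OpenAI ("system"/"user"/"assistant") and Gemini ("user"/"user"/"model"). The helper names buildHistory and chatOnce are hypothetical; the non-streaming call mirrors the goroutine-plus-channel pattern used in pkg/friday/question.go.

package example

import (
	"context"

	"github.com/basenana/friday/pkg/llm"
)

// buildHistory assembles a dialogue using the provider's own role names
// instead of hard-coded "system"/"assistant" strings.
func buildHistory(l llm.LLM, info, question string) []map[string]string {
	return []map[string]string{
		{"role": l.GetSystemModel(), "content": "已知内容: " + info},
		// empty assistant turn keeps roles alternating, as question.go does
		{"role": l.GetAssistantModel(), "content": ""},
		{"role": l.GetUserModel(), "content": question},
	}
}

// chatOnce performs a single non-streaming Chat call and returns the
// content of the final message received on the answers channel.
func chatOnce(ctx context.Context, l llm.LLM, history []map[string]string) (string, error) {
	answers := make(chan map[string]string)
	var chatErr error
	go func() {
		defer close(answers)
		_, chatErr = l.Chat(ctx, false, history, answers)
	}()
	var content string
	for ans := range answers {
		content = ans["content"] // keep the last complete message
	}
	return content, chatErr
}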