Merge pull request #41 from basenana/feature/gemini_chat
update: support gemini chat
zwwhdls committed Mar 20, 2024
2 parents a34e28c + a3bf21e commit 3c9cf98
Showing 8 changed files with 186 additions and 62 deletions.
12 changes: 12 additions & 0 deletions pkg/friday/keywords_test.go
@@ -51,6 +51,18 @@ type FakeKeyWordsLLM struct{}

var _ llm.LLM = &FakeKeyWordsLLM{}

func (f FakeKeyWordsLLM) GetUserModel() string {
return "user"
}

func (f FakeKeyWordsLLM) GetSystemModel() string {
return "system"
}

func (f FakeKeyWordsLLM) GetAssistantModel() string {
return "assistant"
}

func (f FakeKeyWordsLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
return []string{"a, b, c"}, nil, nil
}
82 changes: 53 additions & 29 deletions pkg/friday/question.go
@@ -25,6 +25,8 @@ import (
"github.com/basenana/friday/pkg/models"
)

const remainHistoryNum = 3 // must be odd

func (f *Friday) History(history []map[string]string) *Friday {
f.statement.history = history
return f
@@ -54,7 +56,7 @@ func (f *Friday) Chat(res *ChatState) *Friday {
// search for docs
questions := ""
for _, d := range f.statement.history {
if d["role"] == "user" {
if d["role"] == f.LLM.GetUserModel() {
questions = fmt.Sprintf("%s\n%s", questions, d["content"])
}
}
@@ -68,8 +70,12 @@ func (f *Friday) Chat(res *ChatState) *Friday {
}
f.statement.history = append([]map[string]string{
{
"role": "system",
"content": fmt.Sprintf("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: %s", f.statement.info),
"role": f.LLM.GetSystemModel(),
"content": fmt.Sprintf("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: %s\n", f.statement.info),
},
{
"role": f.LLM.GetAssistantModel(),
"content": "",
},
}, f.statement.history...)

@@ -82,61 +88,79 @@ func (f *Friday) chat(res *ChatState) *Friday {
return f
}
var (
tokens = map[string]int{}
dialogues = []map[string]string{}
)

// If the number of dialogue rounds exceeds 2 rounds, should conclude it.
if len(f.statement.history) >= 5 {
sumDialogue := make([]map[string]string, 0, len(f.statement.history))
if len(f.statement.history) >= remainHistoryNum {
sumDialogue := make([]map[string]string, len(f.statement.history))
copy(sumDialogue, f.statement.history)
sumDialogue = append(sumDialogue, map[string]string{
"role": "system",
sumDialogue[len(sumDialogue)-1] = map[string]string{
"role": f.LLM.GetSystemModel(),
"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
})
}
var (
sumBuf = make(chan map[string]string)
sum = make(map[string]string)
usage = make(map[string]int)
err error
)
defer close(sumBuf)
go func() {
usage, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
defer close(sumBuf)
_, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
if err != nil {
f.Error = err
return
}
}()
if err != nil {
f.Error = err
return f
}
tokens = mergeTokens(usage, tokens)
select {
case <-f.statement.context.Done():
return f
case sum = <-sumBuf:
// add context prompt for dialogue
dialogues = append(dialogues, []map[string]string{
f.statement.history[0],
{
"role": "system",
"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
},
}...)
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-5:len(f.statement.history)]...)
if f.statement.query != nil {
// there has been ingest info, combine them.
dialogues = []map[string]string{
{
"role": f.LLM.GetSystemModel(),
"content": fmt.Sprintf(
"%s\n%s",
f.statement.history[0]["content"],
fmt.Sprintf("这是历史聊天总结作为前情提要:%s\n", sum["content"]),
),
},
{
"role": f.LLM.GetAssistantModel(),
"content": "",
},
}
} else {
dialogues = []map[string]string{
{
"role": f.LLM.GetSystemModel(),
"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
},
{
"role": f.LLM.GetAssistantModel(),
"content": "",
},
}
}
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:len(f.statement.history)]...)
}
if f.Error != nil {
return f
}
} else {
dialogues = make([]map[string]string, len(f.statement.history))
copy(dialogues, f.statement.history)
}

// go for llm
usage, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
_, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
if err != nil {
f.Error = err
return f
}
tokens = mergeTokens(tokens, usage)

res.Tokens = tokens
return f
}

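A note on the compaction flow in question.go above: once the history reaches remainHistoryNum entries, chat() overwrites the last message with a summarize instruction, requests a non-streaming summary from the LLM, then rebuilds the dialogue as a system preamble carrying that summary plus the last remainHistoryNum raw messages. A minimal standalone sketch of that trimming follows; the summarize callback is a hypothetical stand-in for the non-streaming LLM.Chat call, and the sketch omits the patch's empty assistant turn and the merge with the original ingest prompt.

package main

import "fmt"

// compact mirrors the history-compaction logic above: summarize the full
// history, then keep only the most recent `keep` raw messages behind a
// system preamble that carries the summary.
func compact(history []map[string]string, keep int, systemRole string,
	summarize func([]map[string]string) string) []map[string]string {
	if len(history) < keep {
		return history
	}
	out := []map[string]string{{
		"role":    systemRole,
		"content": "Summary of earlier turns: " + summarize(history),
	}}
	return append(out, history[len(history)-keep:]...)
}

func main() {
	h := []map[string]string{
		{"role": "user", "content": "q1"},
		{"role": "assistant", "content": "a1"},
		{"role": "user", "content": "q2"},
	}
	fmt.Println(compact(h, 3, "system",
		func([]map[string]string) string { return "..." }))
}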
12 changes: 12 additions & 0 deletions pkg/friday/question_test.go
@@ -236,6 +236,18 @@ type FakeQuestionLLM struct{}

var _ llm.LLM = &FakeQuestionLLM{}

func (f FakeQuestionLLM) GetUserModel() string {
return "user"
}

func (f FakeQuestionLLM) GetSystemModel() string {
return "system"
}

func (f FakeQuestionLLM) GetAssistantModel() string {
return "assistant"
}

func (f FakeQuestionLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
return []string{"I am an answer"}, nil, nil
}
12 changes: 12 additions & 0 deletions pkg/friday/summary_test.go
@@ -158,6 +158,18 @@ type FakeSummaryLLM struct{}

var _ llm.LLM = &FakeSummaryLLM{}

func (f FakeSummaryLLM) GetUserModel() string {
return "user"
}

func (f FakeSummaryLLM) GetSystemModel() string {
return "system"
}

func (f FakeSummaryLLM) GetAssistantModel() string {
return "assistant"
}

func (f FakeSummaryLLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
return []string{"a b c"}, nil, nil
}
59 changes: 43 additions & 16 deletions pkg/llm/client/gemini/chat.go
@@ -20,10 +20,27 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
)

func (g *Gemini) GetUserModel() string {
return "user"
}

func (g *Gemini) GetSystemModel() string {
return "user"
}

func (g *Gemini) GetAssistantModel() string {
return "model"
}

func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
path := fmt.Sprintf("v1beta/models/%s:streamGenerateContent", *g.conf.Model)
var path string
path = fmt.Sprintf("v1beta/models/%s:generateContent", *g.conf.Model)
if stream {
path = fmt.Sprintf("v1beta/models/%s:streamGenerateContent", *g.conf.Model)
}

contents := make([]map[string]any, 0)
for _, hs := range history {
@@ -39,25 +56,35 @@ func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]str
go func() {
defer close(buf)
err = g.request(ctx, stream, path, "POST", map[string]any{"contents": contents}, buf)
if err != nil {
return
}
}()
if err != nil {
return
}

for line := range buf {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
return nil, err
}
if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
g.log.Errorf("gemini response: %s ", string(line))
return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
}
ans := make(map[string]string)
for _, c := range res.Candidates {
for _, t := range c.Content.Parts {
ans[c.Content.Role] = t.Text
l := strings.TrimSpace(string(line))
if stream {
if !strings.HasPrefix(l, "\"text\"") {
continue
}
// it should be: "text": "xxx"
ans["content"] = l[9 : len(l)-2]
} else {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
return nil, err
}
if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
g.log.Errorf("gemini response: %s ", string(line))
return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
}
for _, c := range res.Candidates {
for _, t := range c.Content.Parts {
ans["role"] = c.Content.Role
ans["content"] = t.Text
}
}
}
select {
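One detail worth noting in the streaming branch above: streamGenerateContent returns pretty-printed JSON, so each content fragment arrives on its own "text": "..." line, and the slice l[9:len(l)-2] drops the nine-byte "text": " prefix plus the trailing quote and comma. A small sketch with a fabricated input line; it assumes the fragment contains no JSON escape sequences, which is also what the patch itself assumes.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Fabricated line from a pretty-printed streamGenerateContent body.
	line := `      "text": "Hello from Gemini",`
	l := strings.TrimSpace(line)
	if strings.HasPrefix(l, `"text"`) {
		// `"text": "` is 9 bytes; the trailing `",` is 2 bytes.
		fmt.Println(l[9 : len(l)-2]) // prints: Hello from Gemini
	}
}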
12 changes: 12 additions & 0 deletions pkg/llm/client/glm-6b/client.go
@@ -76,6 +76,18 @@ type CompletionResult struct {
Time string `json:"time"`
}

func (o *GLM) GetUserModel() string {
return "user"
}

func (o *GLM) GetSystemModel() string {
return "system"
}

func (o *GLM) GetAssistantModel() string {
return "assistant"
}

func (o *GLM) Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) ([]string, map[string]int, error) {
path := ""

55 changes: 38 additions & 17 deletions pkg/llm/client/openai/v1/chat.go
@@ -52,11 +52,19 @@ type ChatStreamChoice struct {
FinishReason string `json:"finish_reason,omitempty"`
}

func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (map[string]int, error) {
return o.chat(ctx, stream, history, answers)
func (o *OpenAIV1) GetUserModel() string {
return "user"
}

func (o *OpenAIV1) chat(ctx context.Context, stream bool, history []map[string]string, resp chan<- map[string]string) (tokens map[string]int, err error) {
func (o *OpenAIV1) GetSystemModel() string {
return "system"
}

func (o *OpenAIV1) GetAssistantModel() string {
return "assistant"
}

func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]string, resp chan<- map[string]string) (tokens map[string]int, err error) {
path := "v1/chat/completions"

data := map[string]interface{}{
@@ -73,30 +81,43 @@ func (o *OpenAIV1) chat(ctx context.Context, stream bool, history []map[string]s

buf := make(chan []byte)
go func() {
defer close(buf)
err = o.request(ctx, stream, path, "POST", data, buf)
close(buf)
if err != nil {
return
}
}()

if err != nil {
return
}
for line := range buf {
var res ChatStreamResult
if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
continue
}
l := string(line)[6:]
err = json.Unmarshal([]byte(l), &res)
if err != nil {
err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
return
var delta map[string]string
if stream {
var res ChatStreamResult
if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
continue
}
// it should be: data: xxx
l := string(line)[6:]
err = json.Unmarshal([]byte(l), &res)
if err != nil {
err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
return
}
delta = res.Choices[0].Delta
} else {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
return
}
delta = res.Choices[0].Message
}

select {
case <-ctx.Done():
err = fmt.Errorf("context timeout in openai chat")
return
case resp <- res.Choices[0].Delta:
case resp <- delta:
continue
}
}
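For the streaming path above, OpenAI-compatible endpoints frame responses as server-sent events: each payload line starts with the six-byte prefix "data: " (hence line[6:]) and the stream ends with a "data: [DONE]" sentinel. A minimal sketch of that framing with a fabricated event; the anonymous struct mirrors only the fields the loop reads.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

func main() {
	// Fabricated SSE line in the shape the streaming loop consumes.
	line := `data: {"choices":[{"delta":{"role":"assistant","content":"Hi"}}]}`
	if !strings.HasPrefix(line, "data:") || strings.Contains(line, "data: [DONE]") {
		return
	}
	var res struct {
		Choices []struct {
			Delta map[string]string `json:"delta"`
		} `json:"choices"`
	}
	// Drop the "data: " prefix, then decode the JSON payload.
	if err := json.Unmarshal([]byte(line[6:]), &res); err != nil {
		panic(err)
	}
	fmt.Println(res.Choices[0].Delta["content"]) // prints: Hi
}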
4 changes: 4 additions & 0 deletions pkg/llm/llm.go
@@ -23,6 +23,10 @@ import (
)

type LLM interface {
GetUserModel() string
GetSystemModel() string
GetAssistantModel() string

// Completion chat with llm just once
Completion(ctx context.Context, prompt prompts.PromptTemplate, parameters map[string]string) (answers []string, tokens map[string]int, err error)
/*
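The three role getters added to the interface let callers tag messages without hard-coding backend role names; per the Gemini client above, both user and system map to "user" and the assistant maps to "model", presumably because the v1beta contents API accepts only those two roles. A hedged usage sketch, assuming a variable client of type llm.LLM:

msg := map[string]string{
	"role":    client.GetSystemModel(), // "system" on OpenAI/GLM, "user" on Gemini
	"content": "You are a helpful assistant.",
}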
