diff --git a/cmd/apps/chat.go b/cmd/apps/chat.go
index 76ffea2..0329b1e 100644
--- a/cmd/apps/chat.go
+++ b/cmd/apps/chat.go
@@ -67,7 +67,6 @@ func chat(dirId int64, history []map[string]string) error {
 	}
 	go func() {
 		f = f.Chat(res)
-		close(resp)
 	}()
 	if f.Error != nil {
 		return f.Error
diff --git a/pkg/friday/question.go b/pkg/friday/question.go
index 77799da..7f57a29 100644
--- a/pkg/friday/question.go
+++ b/pkg/friday/question.go
@@ -105,7 +105,6 @@ func (f *Friday) chat(res *ChatState) *Friday {
 		err error
 	)
 	go func() {
-		defer close(sumBuf)
 		_, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
 		if err != nil {
 			f.Error = err
@@ -114,6 +113,7 @@ func (f *Friday) chat(res *ChatState) *Friday {
 	}()
 	select {
 	case <-f.statement.context.Done():
+		f.Error = errors.New("context canceled")
 		return f
 	case sum = <-sumBuf:
 		// add context prompt for dialogue
diff --git a/pkg/llm/client/gemini/chat.go b/pkg/llm/client/gemini/chat.go
index 2afb166..1e20d1b 100644
--- a/pkg/llm/client/gemini/chat.go
+++ b/pkg/llm/client/gemini/chat.go
@@ -36,6 +36,7 @@ func (g *Gemini) GetAssistantModel() string {
 }
 
 func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
+	defer close(answers)
 	var path string
 	path = fmt.Sprintf("v1beta/models/%s:generateContent", *g.conf.Model)
 	if stream {
@@ -65,11 +66,15 @@ func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]str
 		ans := make(map[string]string)
 		l := strings.TrimSpace(string(line))
 		if stream {
-			if !strings.HasPrefix(l, "\"text\"") {
-				continue
+			if l == "EOF" {
+				ans["content"] = "EOF"
+			} else {
+				if !strings.HasPrefix(l, "\"text\"") {
+					continue
+				}
+				// it should be: "text": "xxx"
+				ans["content"] = l[9 : len(l)-2]
 			}
-			// it should be: "text": "xxx"
-			ans["content"] = l[9 : len(l)-2]
 		} else {
 			var res ChatResult
 			err = json.Unmarshal(line, &res)
diff --git a/pkg/llm/client/openai/v1/chat.go b/pkg/llm/client/openai/v1/chat.go
index 47fc8f0..b1051e3 100644
--- a/pkg/llm/client/openai/v1/chat.go
+++ b/pkg/llm/client/openai/v1/chat.go
@@ -65,6 +65,7 @@ func (o *OpenAIV1) GetAssistantModel() string {
 }
 
 func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]string, resp chan<- map[string]string) (tokens map[string]int, err error) {
+	defer close(resp)
 	path := "v1/chat/completions"
 
 	data := map[string]interface{}{
@@ -92,17 +93,21 @@ func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]s
 		var delta map[string]string
 		if stream {
 			var res ChatStreamResult
-			if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
-				continue
+			if string(line) == "EOF" {
+				delta = map[string]string{"content": "EOF"}
+			} else {
+				if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
+					continue
+				}
+				// it should be: data: xxx
+				l := string(line)[6:]
+				err = json.Unmarshal([]byte(l), &res)
+				if err != nil {
+					err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
+					return
+				}
+				delta = res.Choices[0].Delta
 			}
-			// it should be: data: xxx
-			l := string(line)[6:]
-			err = json.Unmarshal([]byte(l), &res)
-			if err != nil {
-				err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
-				return
-			}
-			delta = res.Choices[0].Delta
 		} else {
 			var res ChatResult
 			err = json.Unmarshal(line, &res)
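The common thread across these hunks is a channel-ownership convention: callers stop calling close() on the response channel (removed in cmd/apps/chat.go and pkg/friday/question.go), and each Chat implementation closes the channel it sends on via defer, emitting an "EOF" sentinel message to mark end-of-stream. Below is a minimal standalone sketch of that pattern; fakeChat is a hypothetical stand-in for the clients' Chat method, not code from this repo.

package main

import (
	"context"
	"fmt"
)

// fakeChat stands in for an LLM client's Chat method. It owns resp,
// so it closes it on every return path via defer; callers must not
// close the channel themselves.
func fakeChat(ctx context.Context, resp chan<- map[string]string) error {
	defer close(resp) // producer closes, mirroring the defer close(...) added in this diff

	for _, chunk := range []string{"hello", " world"} {
		select {
		case <-ctx.Done():
			// Abort cleanly if the caller cancels, as the select on
			// f.statement.context.Done() does in pkg/friday/question.go.
			return ctx.Err()
		case resp <- map[string]string{"content": chunk}:
		}
	}
	// Sentinel end-of-stream message, mirroring the "EOF" handling above.
	resp <- map[string]string{"content": "EOF"}
	return nil
}

func main() {
	ctx := context.Background()
	resp := make(chan map[string]string)
	go func() {
		if err := fakeChat(ctx, resp); err != nil {
			fmt.Println("chat error:", err)
		}
	}()
	// Ranging over resp terminates cleanly because the producer closes it;
	// with the old caller-side close, a slow producer could send on a
	// closed channel and panic.
	for msg := range resp {
		fmt.Println(msg["content"])
	}
}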