Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support streaming and chain of agent #620

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apiserver/pkg/knowledgebase/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func knowledgebase2model(ctx context.Context, c dynamic.Interface, obj *unstruct
CreationTimestamp: &creationtimestamp,
UpdateTimestamp: &condition.LastTransitionTime.Time,
// Embedder info
Embedder: &embedder,
Embedder: &embedder,
EmbedderType: &embedderType,
// Vector info
VectorStore: &generated.TypedObjectReference{
Expand Down
5 changes: 4 additions & 1 deletion controllers/base/application_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"

agentv1alpha1 "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1"
chainv1alpha1 "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1"
arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
)

Expand Down Expand Up @@ -151,7 +153,8 @@ func (r *ApplicationReconciler) validateNodes(ctx context.Context, log logr.Logg
r.setCondition(app, app.Status.ErrorCondition("node should have ref.group setting")...)
return app, ctrl.Result{}, nil
}
if *group != "chain.arcadia.kubeagi.k8s.com.cn" && *group != "agent.arcadia.kubeagi.k8s.com.cn" {
// Only allow chain group or agent node as the ending node
if *group != chainv1alpha1.Group && (*group != agentv1alpha1.Group && node.Ref.Kind != "agent") {
r.setCondition(app, app.Status.ErrorCondition("ending node should be chain or agent")...)
return app, ctrl.Result{}, nil
}
Expand Down
17 changes: 14 additions & 3 deletions pkg/appruntime/agent/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"

"github.com/tmc/langchaingo/agents"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/tools"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -83,8 +84,14 @@ func (p *Executor) Run(ctx context.Context, cli dynamic.Interface, args map[stri
}

// Initialize executor using langchaingo
options := agents.WithMaxIterations(instance.Spec.Options.MaxIterations)
executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, options)
executorOptions := func(o *agents.CreationOptions) {
agents.WithMaxIterations(instance.Spec.Options.MaxIterations)(o)
if needStream, ok := args["_need_stream"].(bool); ok && needStream {
streamHandler := StreamHandler{callbacks.SimpleHandler{}, args}
agents.WithCallbacksHandler(streamHandler)(o)
}
}
executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions)
if err != nil {
return args, fmt.Errorf("failed to initialize executor: %w", err)
}
Expand All @@ -94,6 +101,10 @@ func (p *Executor) Run(ctx context.Context, cli dynamic.Interface, args map[stri
if err != nil {
return args, fmt.Errorf("error when call agent: %w", err)
}
args["_answer"] = response["output"]
klog.FromContext(ctx).V(5).Info("use agent, blocking out:", response["output"])
if err == nil {
args["_answer"] = response["output"]
return args, nil
}
return args, nil
}
49 changes: 49 additions & 0 deletions pkg/appruntime/agent/streamhandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
Copyright 2024 KubeAGI.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package agent

import (
"context"
"fmt"

"github.com/tmc/langchaingo/callbacks"
"k8s.io/klog/v2"
)

// StreamHandler is a callback handler that prints to the standard output streaming.
type StreamHandler struct {
callbacks.SimpleHandler
args map[string]any
}

var _ callbacks.Handler = StreamHandler{}

func (handler StreamHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) {
logger := klog.FromContext(ctx)
if _, ok := handler.args["_answer_stream"]; !ok {
logger.Info("no _answer_stream found, create a new one")
handler.args["_answer_stream"] = make(chan string)
}
streamChan, ok := handler.args["_answer_stream"].(chan string)
if !ok {
err := fmt.Errorf("answer_stream is not chan string, but %T", handler.args["_answer_stream"])
logger.Error(err, "answer_stream is not chan string")
return
}
logger.V(5).Info("stream out:" + string(chunk))
streamChan <- string(chunk)
}
10 changes: 5 additions & 5 deletions pkg/appruntime/app_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
"question": input.Question,
"_answer_stream": respStream,
"_history": input.History,
"context": "",
// Use an empty context before run
"context": "",
}
if input.NeedStream {
out["_need_stream"] = true
}
visited := make(map[string]bool)
waitRunningNodes := list.New()
Expand All @@ -174,7 +178,6 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
for e := waitRunningNodes.Front(); e != nil; e = e.Next() {
e := e.Value.(base.Node)
if !visited[e.Name()] {
out["_need_stream"] = false
reWait := false
for _, n := range e.GetPrevNode() {
if !visited[n.Name()] {
Expand All @@ -186,9 +189,6 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
waitRunningNodes.PushBack(e)
continue
}
if a.EndingNode.Name() == e.Name() && input.NeedStream {
out["_need_stream"] = true
}
klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name()))
if out, err = e.Run(ctx, cli, out); err != nil {
return Output{}, fmt.Errorf("run node %s: %w", e.Name(), err)
Expand Down
34 changes: 16 additions & 18 deletions pkg/appruntime/chain/llmchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,31 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri
return args, fmt.Errorf("can't convert obj to LLMChain: %w", err)
}
options := getChainOptions(instance.Spec.CommonChainConfig)

// Add the answer to the context if it's not empty
if args["_answer"] != nil {
klog.Infoln("get answer from upstream:", args["_answer"])
args["context"] = args["_answer"]
}
chain := chains.NewLLMChain(llm, prompt)
if history != nil {
chain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "")
}
l.LLMChain = *chain

// Skip to do LLMChain again if we already has the answer, such as from agents
if args["_answer"] == nil {
var out string
if needStream, ok := args["_need_stream"].(bool); ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
var out string
if needStream, ok := args["_need_stream"].(bool); ok && needStream {
0xff-dev marked this conversation as resolved.
Show resolved Hide resolved
options = append(options, chains.WithStreamingFunc(stream(args)))
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
out, err = chains.Predict(ctx, l.LLMChain, args)
}
out, err = chains.Predict(ctx, l.LLMChain, args)
}
klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out)
if err == nil {
args["_answer"] = out
return args, nil
}
} else {
klog.Infoln("get answer from upstream:", args["_answer"])
}
klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out)
if err == nil {
args["_answer"] = out
return args, nil
}
return args, fmt.Errorf("llmchain run error: %w", err)
Expand Down
3 changes: 3 additions & 0 deletions pkg/appruntime/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ func (p *Prompt) Run(ctx context.Context, cli dynamic.Interface, args map[string
ps = append(ps, prompts.NewSystemMessagePromptTemplate(instance.Spec.SystemMessage, []string{}))
}
if instance.Spec.UserMessage != "" {
// Add the context by default, and leave it empty
// so we can add more contexts as needed in all agents/chains
instance.Spec.UserMessage = fmt.Sprintf("{{.context}}\n%s", instance.Spec.UserMessage)
ps = append(ps, prompts.NewHumanMessagePromptTemplate(instance.Spec.UserMessage, []string{"question"}))
}
template := prompts.NewChatPromptTemplate(ps)
Expand Down