Skip to content

Commit

Permalink
Merge pull request #620 from nkwangleiGIT/main
Browse files Browse the repository at this point in the history
feat: support streaming and chain of agent
  • Loading branch information
nkwangleiGIT committed Jan 23, 2024
2 parents cad071c + 2377775 commit 3beec19
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 28 deletions.
2 changes: 1 addition & 1 deletion apiserver/pkg/knowledgebase/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func knowledgebase2model(ctx context.Context, c dynamic.Interface, obj *unstruct
CreationTimestamp: &creationtimestamp,
UpdateTimestamp: &condition.LastTransitionTime.Time,
// Embedder info
Embedder: &embedder,
Embedder: &embedder,
EmbedderType: &embedderType,
// Vector info
VectorStore: &generated.TypedObjectReference{
Expand Down
5 changes: 4 additions & 1 deletion controllers/base/application_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"

agentv1alpha1 "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1"
chainv1alpha1 "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1"
arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
)

Expand Down Expand Up @@ -151,7 +153,8 @@ func (r *ApplicationReconciler) validateNodes(ctx context.Context, log logr.Logg
r.setCondition(app, app.Status.ErrorCondition("node should have ref.group setting")...)
return app, ctrl.Result{}, nil
}
if *group != "chain.arcadia.kubeagi.k8s.com.cn" && *group != "agent.arcadia.kubeagi.k8s.com.cn" {
// Only allow chain group or agent node as the ending node
if *group != chainv1alpha1.Group && (*group != agentv1alpha1.Group && node.Ref.Kind != "agent") {
r.setCondition(app, app.Status.ErrorCondition("ending node should be chain or agent")...)
return app, ctrl.Result{}, nil
}
Expand Down
17 changes: 14 additions & 3 deletions pkg/appruntime/agent/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"

"github.com/tmc/langchaingo/agents"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/tools"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -83,8 +84,14 @@ func (p *Executor) Run(ctx context.Context, cli dynamic.Interface, args map[stri
}

// Initialize executor using langchaingo
options := agents.WithMaxIterations(instance.Spec.Options.MaxIterations)
executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, options)
executorOptions := func(o *agents.CreationOptions) {
agents.WithMaxIterations(instance.Spec.Options.MaxIterations)(o)
if needStream, ok := args["_need_stream"].(bool); ok && needStream {
streamHandler := StreamHandler{callbacks.SimpleHandler{}, args}
agents.WithCallbacksHandler(streamHandler)(o)
}
}
executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions)
if err != nil {
return args, fmt.Errorf("failed to initialize executor: %w", err)
}
Expand All @@ -94,6 +101,10 @@ func (p *Executor) Run(ctx context.Context, cli dynamic.Interface, args map[stri
if err != nil {
return args, fmt.Errorf("error when call agent: %w", err)
}
args["_answer"] = response["output"]
klog.FromContext(ctx).V(5).Info("use agent, blocking out:", response["output"])
if err == nil {
args["_answer"] = response["output"]
return args, nil
}
return args, nil
}
49 changes: 49 additions & 0 deletions pkg/appruntime/agent/streamhandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
Copyright 2024 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package agent

import (
"context"
"fmt"

"github.com/tmc/langchaingo/callbacks"
"k8s.io/klog/v2"
)

// StreamHandler is a callback handler that prints to the standard output streaming.
type StreamHandler struct {
callbacks.SimpleHandler
args map[string]any
}

var _ callbacks.Handler = StreamHandler{}

func (handler StreamHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) {
logger := klog.FromContext(ctx)
if _, ok := handler.args["_answer_stream"]; !ok {
logger.Info("no _answer_stream found, create a new one")
handler.args["_answer_stream"] = make(chan string)
}
streamChan, ok := handler.args["_answer_stream"].(chan string)
if !ok {
err := fmt.Errorf("answer_stream is not chan string, but %T", handler.args["_answer_stream"])
logger.Error(err, "answer_stream is not chan string")
return
}
logger.V(5).Info("stream out:" + string(chunk))
streamChan <- string(chunk)
}
10 changes: 5 additions & 5 deletions pkg/appruntime/app_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
"question": input.Question,
"_answer_stream": respStream,
"_history": input.History,
"context": "",
// Use an empty context before run
"context": "",
}
if input.NeedStream {
out["_need_stream"] = true
}
visited := make(map[string]bool)
waitRunningNodes := list.New()
Expand All @@ -174,7 +178,6 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
for e := waitRunningNodes.Front(); e != nil; e = e.Next() {
e := e.Value.(base.Node)
if !visited[e.Name()] {
out["_need_stream"] = false
reWait := false
for _, n := range e.GetPrevNode() {
if !visited[n.Name()] {
Expand All @@ -186,9 +189,6 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
waitRunningNodes.PushBack(e)
continue
}
if a.EndingNode.Name() == e.Name() && input.NeedStream {
out["_need_stream"] = true
}
klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name()))
if out, err = e.Run(ctx, cli, out); err != nil {
return Output{}, fmt.Errorf("run node %s: %w", e.Name(), err)
Expand Down
34 changes: 16 additions & 18 deletions pkg/appruntime/chain/llmchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,31 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri
return args, fmt.Errorf("can't convert obj to LLMChain: %w", err)
}
options := getChainOptions(instance.Spec.CommonChainConfig)

// Add the answer to the context if it's not empty
if args["_answer"] != nil {
klog.Infoln("get answer from upstream:", args["_answer"])
args["context"] = args["_answer"]
}
chain := chains.NewLLMChain(llm, prompt)
if history != nil {
chain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "")
}
l.LLMChain = *chain

// Skip to do LLMChain again if we already has the answer, such as from agents
if args["_answer"] == nil {
var out string
if needStream, ok := args["_need_stream"].(bool); ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
var out string
if needStream, ok := args["_need_stream"].(bool); ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
out, err = chains.Predict(ctx, l.LLMChain, args)
}
out, err = chains.Predict(ctx, l.LLMChain, args)
}
klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out)
if err == nil {
args["_answer"] = out
return args, nil
}
} else {
klog.Infoln("get answer from upstream:", args["_answer"])
}
klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out)
if err == nil {
args["_answer"] = out
return args, nil
}
return args, fmt.Errorf("llmchain run error: %w", err)
Expand Down
3 changes: 3 additions & 0 deletions pkg/appruntime/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ func (p *Prompt) Run(ctx context.Context, cli dynamic.Interface, args map[string
ps = append(ps, prompts.NewSystemMessagePromptTemplate(instance.Spec.SystemMessage, []string{}))
}
if instance.Spec.UserMessage != "" {
// Add the context by default, and leave it empty
// so we can add more contexts as needed in all agents/chains
instance.Spec.UserMessage = fmt.Sprintf("{{.context}}\n%s", instance.Spec.UserMessage)
ps = append(ps, prompts.NewHumanMessagePromptTemplate(instance.Spec.UserMessage, []string{"question"}))
}
template := prompts.NewChatPromptTemplate(ps)
Expand Down

0 comments on commit 3beec19

Please sign in to comment.