diff --git a/apiserver/pkg/knowledgebase/knowledgebase.go b/apiserver/pkg/knowledgebase/knowledgebase.go index 20a989196..f0287418b 100644 --- a/apiserver/pkg/knowledgebase/knowledgebase.go +++ b/apiserver/pkg/knowledgebase/knowledgebase.go @@ -122,7 +122,7 @@ func knowledgebase2model(ctx context.Context, c dynamic.Interface, obj *unstruct CreationTimestamp: &creationtimestamp, UpdateTimestamp: &condition.LastTransitionTime.Time, // Embedder info - Embedder: &embedder, + Embedder: &embedder, EmbedderType: &embedderType, // Vector info VectorStore: &generated.TypedObjectReference{ diff --git a/controllers/base/application_controller.go b/controllers/base/application_controller.go index 2e2f26b3a..871727aa5 100644 --- a/controllers/base/application_controller.go +++ b/controllers/base/application_controller.go @@ -26,6 +26,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + agentv1alpha1 "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" + chainv1alpha1 "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" ) @@ -151,7 +153,8 @@ func (r *ApplicationReconciler) validateNodes(ctx context.Context, log logr.Logg r.setCondition(app, app.Status.ErrorCondition("node should have ref.group setting")...) return app, ctrl.Result{}, nil } - if *group != "chain.arcadia.kubeagi.k8s.com.cn" && *group != "agent.arcadia.kubeagi.k8s.com.cn" { + // Only allow chain group or agent node as the ending node + if *group != chainv1alpha1.Group && (*group != agentv1alpha1.Group && node.Ref.Kind != "agent") { r.setCondition(app, app.Status.ErrorCondition("ending node should be chain or agent")...) return app, ctrl.Result{}, nil } diff --git a/pkg/appruntime/agent/executor.go b/pkg/appruntime/agent/executor.go index ca4f02b39..a1dfd31a6 100644 --- a/pkg/appruntime/agent/executor.go +++ b/pkg/appruntime/agent/executor.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/tmc/langchaingo/agents" + "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/tools" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -83,8 +84,14 @@ func (p *Executor) Run(ctx context.Context, cli dynamic.Interface, args map[stri } // Initialize executor using langchaingo - options := agents.WithMaxIterations(instance.Spec.Options.MaxIterations) - executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, options) + executorOptions := func(o *agents.CreationOptions) { + agents.WithMaxIterations(instance.Spec.Options.MaxIterations)(o) + if needStream, ok := args["_need_stream"].(bool); ok && needStream { + streamHandler := StreamHandler{callbacks.SimpleHandler{}, args} + agents.WithCallbacksHandler(streamHandler)(o) + } + } + executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions) if err != nil { return args, fmt.Errorf("failed to initialize executor: %w", err) } @@ -94,6 +101,10 @@ func (p *Executor) Run(ctx context.Context, cli dynamic.Interface, args map[stri if err != nil { return args, fmt.Errorf("error when call agent: %w", err) } - args["_answer"] = response["output"] + klog.FromContext(ctx).V(5).Info("use agent, blocking out:", response["output"]) + if err == nil { + args["_answer"] = response["output"] + return args, nil + } return args, nil } diff --git a/pkg/appruntime/agent/streamhandler.go b/pkg/appruntime/agent/streamhandler.go new file mode 100644 index 000000000..5913c13e3 --- /dev/null +++ b/pkg/appruntime/agent/streamhandler.go @@ -0,0 +1,49 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package agent + +import ( + "context" + "fmt" + + "github.com/tmc/langchaingo/callbacks" + "k8s.io/klog/v2" +) + +// StreamHandler is a callback handler that prints to the standard output streaming. +type StreamHandler struct { + callbacks.SimpleHandler + args map[string]any +} + +var _ callbacks.Handler = StreamHandler{} + +func (handler StreamHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) { + logger := klog.FromContext(ctx) + if _, ok := handler.args["_answer_stream"]; !ok { + logger.Info("no _answer_stream found, create a new one") + handler.args["_answer_stream"] = make(chan string) + } + streamChan, ok := handler.args["_answer_stream"].(chan string) + if !ok { + err := fmt.Errorf("answer_stream is not chan string, but %T", handler.args["_answer_stream"]) + logger.Error(err, "answer_stream is not chan string") + return + } + logger.V(5).Info("stream out:" + string(chunk)) + streamChan <- string(chunk) +} diff --git a/pkg/appruntime/app_runtime.go b/pkg/appruntime/app_runtime.go index 5920de3f0..0f330ea63 100644 --- a/pkg/appruntime/app_runtime.go +++ b/pkg/appruntime/app_runtime.go @@ -164,7 +164,11 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream "question": input.Question, "_answer_stream": respStream, "_history": input.History, - "context": "", + // Use an empty context before run + "context": "", + } + if input.NeedStream { + out["_need_stream"] = true } visited := make(map[string]bool) waitRunningNodes := list.New() @@ -174,7 +178,6 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream for e := waitRunningNodes.Front(); e != nil; e = e.Next() { e := e.Value.(base.Node) if !visited[e.Name()] { - out["_need_stream"] = false reWait := false for _, n := range e.GetPrevNode() { if !visited[n.Name()] { @@ -186,9 +189,6 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream waitRunningNodes.PushBack(e) continue } - if a.EndingNode.Name() == e.Name() && input.NeedStream { - out["_need_stream"] = true - } klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name())) if out, err = e.Run(ctx, cli, out); err != nil { return Output{}, fmt.Errorf("run node %s: %w", e.Name(), err) diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index 7132b5f4b..27d87840e 100644 --- a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -86,33 +86,31 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri return args, fmt.Errorf("can't convert obj to LLMChain: %w", err) } options := getChainOptions(instance.Spec.CommonChainConfig) - + // Add the answer to the context if it's not empty + if args["_answer"] != nil { + klog.Infoln("get answer from upstream:", args["_answer"]) + args["context"] = args["_answer"] + } chain := chains.NewLLMChain(llm, prompt) if history != nil { chain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") } l.LLMChain = *chain - // Skip to do LLMChain again if we already has the answer, such as from agents - if args["_answer"] == nil { - var out string - if needStream, ok := args["_need_stream"].(bool); ok && needStream { - options = append(options, chains.WithStreamingFunc(stream(args))) + var out string + if needStream, ok := args["_need_stream"].(bool); ok && needStream { + options = append(options, chains.WithStreamingFunc(stream(args))) + out, err = chains.Predict(ctx, l.LLMChain, args, options...) + } else { + if len(options) > 0 { out, err = chains.Predict(ctx, l.LLMChain, args, options...) } else { - if len(options) > 0 { - out, err = chains.Predict(ctx, l.LLMChain, args, options...) - } else { - out, err = chains.Predict(ctx, l.LLMChain, args) - } + out, err = chains.Predict(ctx, l.LLMChain, args) } - klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out) - if err == nil { - args["_answer"] = out - return args, nil - } - } else { - klog.Infoln("get answer from upstream:", args["_answer"]) + } + klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out) + if err == nil { + args["_answer"] = out return args, nil } return args, fmt.Errorf("llmchain run error: %w", err) diff --git a/pkg/appruntime/prompt/prompt.go b/pkg/appruntime/prompt/prompt.go index 8626b7210..566b2ef74 100644 --- a/pkg/appruntime/prompt/prompt.go +++ b/pkg/appruntime/prompt/prompt.go @@ -58,6 +58,9 @@ func (p *Prompt) Run(ctx context.Context, cli dynamic.Interface, args map[string ps = append(ps, prompts.NewSystemMessagePromptTemplate(instance.Spec.SystemMessage, []string{})) } if instance.Spec.UserMessage != "" { + // Add the context by default, and leave it empty + // so we can add more contexts as needed in all agents/chains + instance.Spec.UserMessage = fmt.Sprintf("{{.context}}\n%s", instance.Spec.UserMessage) ps = append(ps, prompts.NewHumanMessagePromptTemplate(instance.Spec.UserMessage, []string{"question"})) } template := prompts.NewChatPromptTemplate(ps)