diff --git a/api/app-node/common_type.go b/api/app-node/common_type.go index 18493f19b..06bc4bedc 100644 --- a/api/app-node/common_type.go +++ b/api/app-node/common_type.go @@ -25,8 +25,15 @@ import ( const ( InputLengthAnnotationKey = v1alpha1.Group + `/input-rules` OutputLengthAnnotationKey = v1alpha1.Group + `/output-rules` + + // ConversationKnowledgebaseName is the placeholder name of the conversation knowledgebase + ConversationKnowledgebaseName = "conversation_knowledgebase_" ) +func IsPlaceholderConversationKnowledgebase(name string) bool { + return name == ConversationKnowledgebaseName +} + type Ref struct { Kind string `json:"kind,omitempty"` Group string `json:"group,omitempty"` diff --git a/api/app-node/retriever/v1alpha1/mergerretriever_types.go b/api/app-node/retriever/v1alpha1/mergerretriever_types.go new file mode 100644 index 000000000..cb8991ea9 --- /dev/null +++ b/api/app-node/retriever/v1alpha1/mergerretriever_types.go @@ -0,0 +1,76 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + node "github.com/kubeagi/arcadia/api/app-node" + "github.com/kubeagi/arcadia/api/base/v1alpha1" +) + +// MergerRetrieverSpec defines the desired state of MergerRetriever +type MergerRetrieverSpec struct { + v1alpha1.CommonSpec `json:",inline"` +} + +// MergerRetrieverStatus defines the observed state of MergerRetriever +type MergerRetrieverStatus struct { + // ObservedGeneration is the last observed generation. + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // ConditionedStatus is the current status + v1alpha1.ConditionedStatus `json:",inline"` +} + +//+kubebuilder:object:root=true +//+kubebuilder:subresource:status + +// MergerRetriever is the Schema for the MergerRetriever API +type MergerRetriever struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec MergerRetrieverSpec `json:"spec,omitempty"` + Status MergerRetrieverStatus `json:"status,omitempty"` +} + +//+kubebuilder:object:root=true + +// MergerRetrieverList contains a list of MergerRetriever +type MergerRetrieverList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []MergerRetriever `json:"items"` +} + +func init() { + SchemeBuilder.Register(&MergerRetriever{}, &MergerRetrieverList{}) +} + +var _ node.Node = (*MergerRetriever)(nil) + +func (c *MergerRetriever) SetRef() { + annotations := node.SetRefAnnotations(c.GetAnnotations(), []node.Ref{node.RetrieverRef.Len(1)}, []node.Ref{node.RetrievalQAChainRef.Len(1)}) + if c.GetAnnotations() == nil { + c.SetAnnotations(annotations) + } + for k, v := range annotations { + c.Annotations[k] = v + } +} diff --git a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go index 648d0d494..be0faddb5 100644 --- a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go @@ -138,6 +138,97 @@ func (in *KnowledgeBaseRetrieverStatus) DeepCopy() *KnowledgeBaseRetrieverStatus return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetriever) DeepCopyInto(out *MergerRetriever) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetriever. +func (in *MergerRetriever) DeepCopy() *MergerRetriever { + if in == nil { + return nil + } + out := new(MergerRetriever) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetriever) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverList) DeepCopyInto(out *MergerRetrieverList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]MergerRetriever, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverList. +func (in *MergerRetrieverList) DeepCopy() *MergerRetrieverList { + if in == nil { + return nil + } + out := new(MergerRetrieverList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetrieverList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverSpec) DeepCopyInto(out *MergerRetrieverSpec) { + *out = *in + out.CommonSpec = in.CommonSpec +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverSpec. +func (in *MergerRetrieverSpec) DeepCopy() *MergerRetrieverSpec { + if in == nil { + return nil + } + out := new(MergerRetrieverSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverStatus) DeepCopyInto(out *MergerRetrieverStatus) { + *out = *in + in.ConditionedStatus.DeepCopyInto(&out.ConditionedStatus) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverStatus. +func (in *MergerRetrieverStatus) DeepCopy() *MergerRetrieverStatus { + if in == nil { + return nil + } + out := new(MergerRetrieverStatus) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MultiQueryRetriever) DeepCopyInto(out *MultiQueryRetriever) { *out = *in diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 9e178de29..00de60463 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -87,6 +87,7 @@ type ComplexityRoot struct { EnableRerank func(childComplexity int) int EnableUploadFile func(childComplexity int) int Knowledgebase func(childComplexity int) int + Knowledgebases func(childComplexity int) int Llm func(childComplexity int) int MaxLength func(childComplexity int) int MaxTokens func(childComplexity int) int @@ -1095,6 +1096,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Application.Knowledgebase(childComplexity), true + case "Application.knowledgebases": + if e.complexity.Application.Knowledgebases == nil { + break + } + + return e.complexity.Application.Knowledgebases(childComplexity), true + case "Application.llm": if e.complexity.Application.Llm == nil { break @@ -5164,10 +5172,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -5477,7 +5490,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 @@ -9974,6 +9992,47 @@ func (ec *executionContext) fieldContext_Application_conversionWindowSize(ctx co return fc, nil } +func (ec *executionContext) _Application_knowledgebases(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Application_knowledgebases(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Knowledgebases, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.([]*string) + fc.Result = res + return ec.marshalOString2ᚕᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Application_knowledgebases(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Application", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Application_knowledgebase(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Application_knowledgebase(ctx, field) if err != nil { @@ -11620,6 +11679,8 @@ func (ec *executionContext) fieldContext_ApplicationMutation_updateApplicationCo return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -11729,6 +11790,8 @@ func (ec *executionContext) fieldContext_ApplicationQuery_getApplication(ctx con return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -28770,6 +28833,8 @@ func (ec *executionContext) fieldContext_RAG_application(ctx context.Context, fi return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -39102,7 +39167,7 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} + fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "knowledgebases", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -39179,6 +39244,13 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte return it, err } it.Knowledgebase = data + case "knowledgebases": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("knowledgebases")) + data, err := ec.unmarshalOString2ᚕᚖstring(ctx, v) + if err != nil { + return it, err + } + it.Knowledgebases = data case "scoreThreshold": ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scoreThreshold")) data, err := ec.unmarshalOFloat2ᚖfloat64(ctx, v) @@ -40500,6 +40572,8 @@ func (ec *executionContext) _Application(ctx context.Context, sel ast.SelectionS out.Values[i] = ec._Application_maxTokens(ctx, field, obj) case "conversionWindowSize": out.Values[i] = ec._Application_conversionWindowSize(ctx, field, obj) + case "knowledgebases": + out.Values[i] = ec._Application_knowledgebases(ctx, field, obj) case "knowledgebase": out.Values[i] = ec._Application_knowledgebase(ctx, field, obj) case "scoreThreshold": diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 7f4c6e10d..c3b69a368 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -54,6 +54,8 @@ type Application struct { MaxTokens *int `json:"maxTokens,omitempty"` // conversionWindowSize 对话轮次 ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` // scoreThreshold 最终返回结果的最低相似度 @@ -1698,6 +1700,8 @@ type UpdateApplicationConfigInput struct { ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // scoreThreshold 最终返回结果的最低相似度 ScoreThreshold *float64 `json:"scoreThreshold,omitempty"` // numDocuments 最终返回结果的引用上限 diff --git a/apiserver/graph/schema/application.gql b/apiserver/graph/schema/application.gql index 25462e38b..0ab8fad26 100644 --- a/apiserver/graph/schema/application.gql +++ b/apiserver/graph/schema/application.gql @@ -75,6 +75,7 @@ mutation updateApplicationConfig($input: UpdateApplicationConfigInput!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn @@ -127,6 +128,7 @@ query getApplication($name: String!, $namespace: String!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn diff --git a/apiserver/graph/schema/application.graphqls b/apiserver/graph/schema/application.graphqls index 60b890c32..6bd6b0eb4 100644 --- a/apiserver/graph/schema/application.graphqls +++ b/apiserver/graph/schema/application.graphqls @@ -59,10 +59,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -372,7 +377,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index 0b8c926e0..7aa100b0c 100644 --- a/apiserver/pkg/application/application.go +++ b/apiserver/pkg/application/application.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + appnode "github.com/kubeagi/arcadia/api/app-node" apiagent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" apidocumentloader "github.com/kubeagi/arcadia/api/app-node/documentloader/v1alpha1" @@ -65,7 +66,7 @@ func addDefaultValue(gApp *generated.Application, app *v1alpha1.Application) { if len(app.Spec.Nodes) > 0 { return } - gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") + // gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") gApp.NumDocuments = pointer.Int(5) gApp.ScoreThreshold = pointer.Float64(0.3) gApp.Temperature = pointer.Float64(0.7) @@ -135,7 +136,7 @@ func cr2app(prompt *apiprompt.Prompt, chainConfig *apichain.CommonChainConfig, r case "llm": gApp.Llm = node.Ref.Name case "knowledgebase": - gApp.Knowledgebase = pointer.String(node.Ref.Name) + gApp.Knowledgebases = append(gApp.Knowledgebases, pointer.String(node.Ref.Name)) } } if retriever != nil { @@ -428,6 +429,9 @@ func ListApplicationMeatadatas(ctx context.Context, c client.Client, input gener } func UpdateApplicationConfig(ctx context.Context, c client.Client, input generated.UpdateApplicationConfigInput) (*generated.Application, error) { + input.Knowledgebases = ConvertKnowledgebase2Knowledgebases(input.Knowledgebase, input.Knowledgebases) // nolint:staticcheck + hasKnowledgebase := utils.HasValues(input.Knowledgebases) || pointer.BoolDeref(input.EnableUploadFile, false) // input has knowledgebases or enable upload file(may have conversation knowledgebase) + // check tool name not duplicated if len(input.Tools) != 0 { key := make(map[string]bool, len(input.Tools)) for _, tool := range input.Tools { @@ -437,6 +441,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat key[tool.Name] = true } } + key := types.NamespacedName{Namespace: input.Namespace, Name: input.Name} // get application cr, if not exist, return error app := &v1alpha1.Application{} @@ -507,12 +512,43 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } } + // create or update placeholder conversation knowledgebase, if enable upload file + var conversationKnowledgebase *v1alpha1.KnowledgeBase + if pointer.BoolDeref(input.EnableUploadFile, false) { + conversationKnowledgebase = &v1alpha1.KnowledgeBase{ + ObjectMeta: metav1.ObjectMeta{ + Name: appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + Spec: v1alpha1.KnowledgeBaseSpec{ + CommonSpec: v1alpha1.CommonSpec{ + DisplayName: "conversation knowledgebase placeholder", + Description: "conversation knowledgebase placeholder", + }, + Type: v1alpha1.KnowledgeBaseTypeConversation, + }, + } + if _, err := controllerutil.CreateOrUpdate(ctx, c, conversationKnowledgebase, func() error { + return nil + }); err != nil { + return nil, err + } + } else { + conversationKnowledgebase = &v1alpha1.KnowledgeBase{ + ObjectMeta: metav1.ObjectMeta{ + Name: appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, conversationKnowledgebase) + } + // create or update chain var ( chainConfig *apichain.CommonChainConfig retriever *apiretriever.CommonRetrieverConfig ) - if utils.HasValue(input.Knowledgebase) { + if hasKnowledgebase { qachain := &apichain.RetrievalQAChain{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -544,6 +580,13 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + llmchain := &apichain.LLMChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, llmchain) chainConfig = &qachain.Spec.CommonChainConfig } else { llmchain := &apichain.LLMChain{ @@ -577,12 +620,22 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + qachain := &apichain.RetrievalQAChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, qachain) chainConfig = &llmchain.Spec.CommonChainConfig } // create or update retrievers - // knowledgebaseRetriever (must have) -> multiQueryRetriever (optional) -> rerankRetriever (optional) -> Output - hasKnowledgebaseRetriever := utils.HasValue(input.Knowledgebase) + // (must have) (must have) (optional) (optional) + // knowledgebase1 -> knowledgebaseRetriever1 + // knowledgebase2 -> knowledgebaseRetriever2 --> mergerRetriever --> multiQueryRetriever --> rerankRetriever --> Output + // conversation knowledgebase (placeholder or has true data) -> knowledgebaseRetriever3 + hasKnowledgebaseRetriever := hasKnowledgebase hasMultiQueryRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableMultiQuery, false) hasRerankRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableRerank, false) rerankModel := "" @@ -614,6 +667,36 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat return nil, err } retriever = &knowledgebaseRetriever.Spec.CommonRetrieverConfig + } else { + knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, knowledgebaseRetriever) + } + + if hasKnowledgebaseRetriever { + mergerRetriever := &apiretriever.MergerRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + if _, err = controllerutil.CreateOrUpdate(ctx, c, mergerRetriever, func() error { + return nil + }); err != nil { + return nil, err + } + } else { + mergerRetriever := &apiretriever.MergerRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, mergerRetriever) } if hasMultiQueryRetriever { @@ -650,7 +733,16 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat return nil, err } retriever = &multiQueryRetriever.Spec.CommonRetrieverConfig + } else { + multiQueryRetriever := &apiretriever.MultiQueryRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, multiQueryRetriever) } + if hasRerankRetriever { rerankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ @@ -692,17 +784,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat if rerankRetriever.Spec.Model != nil { rerankModel = rerankRetriever.Spec.Model.Name } - } - if !hasMultiQueryRetriever { - multiQueryRetriever := &apiretriever.MultiQueryRetriever{ - ObjectMeta: metav1.ObjectMeta{ - Name: input.Name, - Namespace: input.Namespace, - }, - } - _ = c.Delete(ctx, multiQueryRetriever) - } - if !hasRerankRetriever { + } else { reRankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -711,6 +793,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } _ = c.Delete(ctx, reRankRetriever) } + // create or update agent for tools var agent *apiagent.Agent if len(input.Tools) != 0 { @@ -762,7 +845,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfigInput, hasMultiQueryRetriever, hasRerankRetriever bool) error { - app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) + app.Spec.Nodes = redefineNodes(input.Knowledgebases, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) app.Spec.Prologue = pointer.StringDeref(input.Prologue, app.Spec.Prologue) app.Spec.ShowRespInfo = pointer.BoolDeref(input.ShowRespInfo, app.Spec.ShowRespInfo) app.Spec.ShowRetrievalInfo = pointer.BoolDeref(input.ShowRetrievalInfo, app.Spec.ShowRetrievalInfo) @@ -776,7 +859,7 @@ func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfi } // redefineNodes redefine nodes in application -func redefineNodes(knowledgebase *string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { +func redefineNodes(knowledgebases []*string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { nodes = []v1alpha1.Node{ { NodeConfig: v1alpha1.NodeConfig{ @@ -805,41 +888,25 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName }, NextNodeName: []string{"chain-node"}, }, - } - if pointer.BoolDeref(enableUploadFile, false) { - nodes = append(nodes, v1alpha1.Node{ + { NodeConfig: v1alpha1.NodeConfig{ - Name: "documentloader-node", - DisplayName: "documentloader", - Description: "文档加载,可选", + Name: "llm-node", + DisplayName: "llm", + Description: "设定大模型的访问信息", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "DocumentLoader", - Name: name, + Kind: "LLM", + Name: llmName, Namespace: &namespace, }, }, NextNodeName: []string{"chain-node"}, - }) - } - nodes = append(nodes, v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "llm-node", - DisplayName: "llm", - Description: "设定大模型的访问信息", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "LLM", - Name: llmName, - Namespace: &namespace, - }, }, - NextNodeName: []string{"chain-node"}, - }) + } if len(tools) != 0 { nodes[len(nodes)-1].NextNodeName = []string{"chain-node", "agent-node"} } - if knowledgebase == nil { + if !utils.HasValues(knowledgebases) { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "chain-node", @@ -855,50 +922,66 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } else { - nodes = append(nodes, - v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "knowledgebase-node", - DisplayName: "知识库", - Description: "连接知识库", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBase", - Name: pointer.StringDeref(knowledgebase, ""), - Namespace: &namespace, + for _, knowledgebase := range knowledgebases { + nodes = append(nodes, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "knowledgebase-node", + DisplayName: "知识库", + Description: "连接知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBase", + Name: pointer.StringDeref(knowledgebase, ""), + Namespace: &namespace, + }, }, + NextNodeName: []string{"retriever-node"}, }, - NextNodeName: []string{"retriever-node"}, - }, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "retriever-node", + DisplayName: "从知识库提取信息的retriever", + Description: "连接应用和知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBaseRetriever", + Name: name, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"mergerretriever-node"}, + }) + } + mergerRetrierNextNodeName := "chain-node" + multiqueryRetrieverNodeName := "chain-node" + switch { + case hasMultiQueryRetriever: + mergerRetrierNextNodeName = "multiqueryretriever-node" + if hasRerankRetriever { + multiqueryRetrieverNodeName = "rerankretriever-node" + } + case !hasMultiQueryRetriever && hasRerankRetriever: + mergerRetrierNextNodeName = "rerankretriever-node" + case !hasMultiQueryRetriever && !hasRerankRetriever: + mergerRetrierNextNodeName = "chain-node" + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ - Name: "retriever-node", - DisplayName: "从知识库提取信息的retriever", - Description: "连接应用和知识库", + Name: "mergerretriever-node", + DisplayName: "知识库合并retriever", + Description: "知识库合并retriever", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBaseRetriever", + Kind: "MergerRetriever", Name: name, Namespace: &namespace, }, }, - NextNodeName: []string{"chain-node"}, + NextNodeName: []string{mergerRetrierNextNodeName}, }) - knowledgebaseRetrierNextNodeName := "chain-node" - switch { - case hasMultiQueryRetriever: - knowledgebaseRetrierNextNodeName = "multiqueryretriever-node" - case !hasMultiQueryRetriever && hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "rerankretriever-node" - case !hasMultiQueryRetriever && !hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "chain-node" - } - nodes[len(nodes)-1].NextNodeName = []string{knowledgebaseRetrierNextNodeName} if hasMultiQueryRetriever { - nextNodeName := "chain-node" - if hasRerankRetriever { - nextNodeName = "rerankretriever-node" - } nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -912,7 +995,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName Namespace: &namespace, }, }, - NextNodeName: []string{nextNodeName}, + NextNodeName: []string{multiqueryRetrieverNodeName}, }) } if hasRerankRetriever { @@ -948,6 +1031,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } + if len(tools) != 0 { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -964,6 +1048,24 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"chain-node"}, }) } + + if pointer.BoolDeref(enableUploadFile, false) { + nodes = append(nodes, v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "documentloader-node", + DisplayName: "documentloader", + Description: "文档加载,可选", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "DocumentLoader", + Name: name, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"chain-node"}, + }) + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "Output", @@ -979,6 +1081,17 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName return nodes } +// deprecated +func ConvertKnowledgebase2Knowledgebases(knowledgebase *string, knowledgebases []*string) []*string { + if knowledgebases != nil { + return knowledgebases + } + if knowledgebase == nil { + return nil + } + return []*string{knowledgebase} +} + func UploadIcon(ctx context.Context, client client.Client, icon, appName, namespace string) (string, error) { if strings.HasPrefix(icon, "data:image") { imgBytes, err := pkgutils.ParseBase64ImageBytes(icon) diff --git a/apiserver/pkg/chat/chat_server.go b/apiserver/pkg/chat/chat_server.go index 8725abb81..c084b2761 100644 --- a/apiserver/pkg/chat/chat_server.go +++ b/apiserver/pkg/chat/chat_server.go @@ -29,7 +29,6 @@ import ( langchainllms "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/memory" "github.com/tmc/langchaingo/prompts" - langchainschema "github.com/tmc/langchaingo/schema" "golang.org/x/sync/errgroup" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -302,32 +301,29 @@ func (cs *ChatServer) ListPromptStarters(ctx context.Context, req APPMetadata, l if finish != nil { defer finish() } - v, ok := outArg[base.LangchaingoRetrieverKeyInArg] - if ok { - r, ok := v.(langchainschema.Retriever) - if ok { - doc, err := r.GetRelevantDocuments(ctx, "") - if err != nil { - return nil, err - } - for _, d := range doc { - hasAnswer := false - // has answer, means qa.csv, just return the question - v, ok := d.Metadata[documentloaders.AnswerCol] - if ok { - answer, ok := v.(string) - if ok && answer != "" { - question := strings.TrimSuffix(d.PageContent, "\na: "+answer) - promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) - hasAnswer = true - if len(promptStarters) == limit { - break - } + retrievers, err := base.GetRetrieversFromArg(outArg) + if err == nil && len(retrievers) > 0 { + doc, err := retrievers[0].GetRelevantDocuments(ctx, "") + if err != nil { + return nil, err + } + for _, d := range doc { + hasAnswer := false + // has answer, means qa.csv, just return the question + v, ok := d.Metadata[documentloaders.AnswerCol] + if ok { + answer, ok := v.(string) + if ok && answer != "" { + question := strings.TrimSuffix(d.PageContent, "\na: "+answer) + promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) + hasAnswer = true + if len(promptStarters) == limit { + break } } - if !hasAnswer { - content.WriteString(d.PageContent + "\n") - } + } + if !hasAnswer { + content.WriteString(d.PageContent + "\n") } } } diff --git a/apiserver/pkg/utils/structured.go b/apiserver/pkg/utils/structured.go index 20d1916dc..f72c319e0 100644 --- a/apiserver/pkg/utils/structured.go +++ b/apiserver/pkg/utils/structured.go @@ -38,3 +38,7 @@ func MapStr2Any(input map[string]string) map[string]any { func HasValue(s *string) bool { return s != nil && strings.TrimSpace(*s) != "" } + +func HasValues(s []*string) bool { + return len(s) > 0 && strings.TrimSpace(*s[0]) != "" +} diff --git a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index a74ee8d19..4f25507cf 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(Also a KubeBB Component) for KubeAGI Arcadia type: application -version: 0.3.29 +version: 0.3.30 appVersion: "0.2.1" keywords: diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/go.mod b/go.mod index 8f8600ff6..d4076a31f 100644 --- a/go.mod +++ b/go.mod @@ -215,4 +215,4 @@ require ( sigs.k8s.io/yaml v1.3.0 // indirect ) -replace github.com/tmc/langchaingo => github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 // branch dev +replace github.com/tmc/langchaingo => github.com/Abirdcfly/langchaingo v0.0.0-20240402052449-3e0b4b8fedf4 // branch dev diff --git a/go.sum b/go.sum index b8718c1b0..d1126d6a5 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/99designs/gqlgen v0.17.40 h1:/l8JcEVQ93wqIfmH9VS1jsAkwm6eAF1NwQn3N+SDqBY= github.com/99designs/gqlgen v0.17.40/go.mod h1:b62q1USk82GYIVjC60h02YguAZLqYZtvWml8KkhJps4= +github.com/Abirdcfly/langchaingo v0.0.0-20240402052449-3e0b4b8fedf4 h1:32Ne6AEwWJr8Hi40LgI+i5oVc8nB3hbJZ8g3hcEaqb8= +github.com/Abirdcfly/langchaingo v0.0.0-20240402052449-3e0b4b8fedf4/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= @@ -549,8 +551,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 h1:4VbKHgpTrG/EPWIn4n3FMIedvz3z2aYL2E0Z7RIEvik= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= diff --git a/pkg/appruntime/app_runtime.go b/pkg/appruntime/app_runtime.go index 9045ca778..abee5dc50 100644 --- a/pkg/appruntime/app_runtime.go +++ b/pkg/appruntime/app_runtime.go @@ -25,8 +25,6 @@ import ( "strings" langchaingoschema "github.com/tmc/langchaingo/schema" - apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/strings/slices" "sigs.k8s.io/controller-runtime/pkg/client" @@ -149,32 +147,12 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha base.InputIsNeedStreamKeyInArg: input.NeedStream, base.LangchaingoChatMessageHistoryKeyInArg: input.History, // Use an empty context before run - "context": "", + "context": "", + base.ConversationIDInArg: input.ConversationID, } if a.Spec.DocNullReturn != "" { out[base.APPDocNullReturn] = a.Spec.DocNullReturn } - if input.ConversationID != "" { // means this is not a new conversation - conversationKnowledgebaseExist := true - kb := &arcadiav1alpha1.KnowledgeBase{} - err := cli.Get(ctx, types.NamespacedName{Namespace: a.Namespace, Name: input.ConversationID}, kb) - if err != nil { - if apierrors.IsNotFound(err) { - conversationKnowledgebaseExist = false - // TODO We can search for whether there should be a conversation knowledgebase from the pg - klog.FromContext(ctx).V(5).Info("conversation knowledgebase not exist", "ConversationID", input.ConversationID) - } else { - return output, err - } - } - if conversationKnowledgebaseExist { - if kb.Status.IsReady() { - out[base.ConversationKnowledgeBaseInArg] = kb - } else { - klog.FromContext(ctx).V(3).Info("conversation knowledgebase not ready", "ConversationID", input.ConversationID) - } - } - } visited := make(map[string]bool) waitRunningNodes := list.New() for _, v := range a.StartingNodes { @@ -280,6 +258,9 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha case "multiqueryretriever": logger.V(3).Info("initnode multiqueryretriever") return retriever.NewMultiQueryRetriever(baseNode), nil + case "mergeretriever": + logger.V(3).Info("initnode mergeretriever") + return retriever.NewMergerRetriever(baseNode), nil default: return nil, err } diff --git a/pkg/appruntime/base/keyword.go b/pkg/appruntime/base/keyword.go index 465e19443..70c26da3d 100644 --- a/pkg/appruntime/base/keyword.go +++ b/pkg/appruntime/base/keyword.go @@ -16,6 +16,12 @@ limitations under the License. package base +import ( + "errors" + + langchainschema "github.com/tmc/langchaingo/schema" +) + const ( InputQuestionKeyInArg = "question" InputIsNeedStreamKeyInArg = "_need_stream" @@ -25,9 +31,46 @@ const ( MapReduceDocumentOutputInArg = "_mapreduce_document_answer" OutputAnserStreamChanKeyInArg = "_answer_stream" RuntimeRetrieverReferencesKeyInArg = "_references" - LangchaingoRetrieverKeyInArg = "retriever" + LangchaingoRetrieversKeyInArg = "retrievers" LangchaingoLLMKeyInArg = "llm" LangchaingoPromptKeyInArg = "prompt" APPDocNullReturn = "_app_doc_null_return" ConversationKnowledgeBaseInArg = "_conversation_knowledgebase" // the conversation Knowledgebase cr in args, status has ready + ConversationIDInArg = "_conversation_id" ) + +func GetInputQuestionFromArg(args map[string]any) (string, error) { + q, ok := args[InputQuestionKeyInArg] + if !ok { + return "", errors.New("no question in args") + } + query, ok := q.(string) + if !ok || len(query) == 0 { + return "", errors.New("empty question") + } + return query, nil +} + +func GetRetrieversFromArg(args map[string]any) ([]langchainschema.Retriever, error) { + v, ok := args[LangchaingoRetrieversKeyInArg] + if !ok { + return nil, errors.New("no retrievers in args") + } + retrievers, ok := v.([]langchainschema.Retriever) + if !ok { + return nil, errors.New("retrievers not []schema.Retriever") + } + return retrievers, nil +} + +func GetAPPDocNullReturnFromArg(args map[string]any) (string, error) { + v, ok := args[APPDocNullReturn] + if !ok { + return "", nil + } + docNullReturn, ok := v.(string) + if !ok { + return "", errors.New("app doc null return not string type") + } + return docNullReturn, nil +} diff --git a/pkg/appruntime/knowledgebase/knowledgebase.go b/pkg/appruntime/knowledgebase/knowledgebase.go index e57ee21d1..8448f7a36 100644 --- a/pkg/appruntime/knowledgebase/knowledgebase.go +++ b/pkg/appruntime/knowledgebase/knowledgebase.go @@ -23,6 +23,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" ) @@ -39,6 +40,9 @@ func NewKnowledgebase(baseNode base.BaseNode) *Knowledgebase { } func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return nil + } instance := &v1alpha1.KnowledgeBase{} if err := cli.Get(ctx, types.NamespacedName{Namespace: k.RefNamespace(), Name: k.Ref.Name}, instance); err != nil { return fmt.Errorf("can't find the knowledgebase in cluster: %w", err) @@ -47,10 +51,9 @@ func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[strin return nil } -func (k *Knowledgebase) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { - return args, nil -} - func (k *Knowledgebase) Ready() (isReady bool, msg string) { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return true, "" + } return k.Instance.Status.IsReadyOrGetReadyMessage() } diff --git a/pkg/appruntime/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go index 28450a6fd..c9603eb40 100644 --- a/pkg/appruntime/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,16 +18,16 @@ package retriever import ( "context" - "errors" "fmt" - langchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" @@ -101,7 +101,21 @@ func (l *KnowledgeBaseRetriever) Cleanup() { func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, knowledgebaseName, knowledgebaseNamespace string, retrieverConfig apiretriever.CommonRetrieverConfig, args map[string]any) (outArg map[string]any, finish func(), err error) { knowledgebase := &v1alpha1.KnowledgeBase{} + isConversationKnowledgebase := appnode.IsPlaceholderConversationKnowledgebase(knowledgebaseName) + if isConversationKnowledgebase { + v, ok := args[base.ConversationIDInArg] + if ok { + conversationID, ok := v.(string) + if ok && conversationID != "" { + knowledgebaseName = conversationID + } + } + } if err := cli.Get(ctx, types.NamespacedName{Namespace: knowledgebaseNamespace, Name: knowledgebaseName}, knowledgebase); err != nil { + if isConversationKnowledgebase && apierrors.IsNotFound(err) { // When there is a conversationID, look for the corresponding conversation knowledgebase. This knowledgebase may not exist. This is not a error + // TODO We can search for whether there should be a conversation knowledgebase from the pg + return args, nil, nil + } return nil, nil, fmt.Errorf("can't find the knowledgebase in cluster: %w", err) } @@ -138,40 +152,14 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } retriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} - question, ok := args["question"] - if !ok { - return nil, finish, errors.New("no question in args") - } - query, ok := question.(string) - if !ok { - return nil, finish, errors.New("question not string") + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return nil, finish, err } docs, err := retriever.GetRelevantDocuments(ctx, query) if err != nil { return nil, finish, fmt.Errorf("can't get relevant documents: %w", err) } - oldDocs := make([]langchaingoschema.Document, 0) - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if ok { - // may exist other retriever, like conversation retriever - oldRetriever, ok := v.(langchaingoschema.Retriever) - if ok { - oldDocs, err = oldRetriever.GetRelevantDocuments(ctx, query) - if err != nil { - return nil, finish, fmt.Errorf("can't get old doc: %w", err) - } - } - } - if len(docs) == 0 && len(oldDocs) == 0 { - // FIXME: 需要决定当知识库找不到相关内容,但是conversation知识库存在文档时如何处理 - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, finish, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - } // pgvector get score means vector distance, similarity = 1 - vector distance // chroma get score means similarity // we want similarity finally. @@ -181,7 +169,7 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } } docs, refs := ConvertDocuments(ctx, docs, "knowledgebase") - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: append(docs, oldDocs...), Name: "KnowledgebaseRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = &Fakeretriever{Docs: docs, Name: "KnowledgebaseRetriever"} AddReferencesToArgs(args, refs) return args, finish, nil } diff --git a/pkg/appruntime/retriever/mergerretriever.go b/pkg/appruntime/retriever/mergerretriever.go new file mode 100644 index 000000000..17fbdcde4 --- /dev/null +++ b/pkg/appruntime/retriever/mergerretriever.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retriever + +import ( + "context" + "fmt" + + langchainretrievers "github.com/tmc/langchaingo/retrievers" + langchainschema "github.com/tmc/langchaingo/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime/base" +) + +type MergerRetriever struct { + base.BaseNode + Instance *apiretriever.MergerRetriever +} + +func NewMergerRetriever(baseNode base.BaseNode) *MergerRetriever { + return &MergerRetriever{ + BaseNode: baseNode, + } +} + +func (l *MergerRetriever) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + instance := &apiretriever.MergerRetriever{} + if err := cli.Get(ctx, types.NamespacedName{Namespace: l.RefNamespace(), Name: l.BaseNode.Ref.Name}, instance); err != nil { + return fmt.Errorf("can't find the merger retriever in cluster: %w", err) + } + l.Instance = instance + return nil +} + +func (l *MergerRetriever) Run(ctx context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err + } + r := langchainretrievers.NewMergerRetriever(retrievers) + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return args, err + } + docs, err := r.GetRelevantDocuments(ctx, query) + if err != nil { + return args, err + } + if len(docs) == 0 { + docNullReturn, err := base.GetAPPDocNullReturnFromArg(args) + if err == nil && len(docNullReturn) > 0 { + return args, &base.RetrieverGetNullDocError{Msg: docNullReturn} + } + } + // replace, not append + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Name: "mergerRetriever", Docs: docs}} + return args, nil +} + +func (l *MergerRetriever) Ready() (isReady bool, msg string) { + return l.Instance.Status.IsReadyOrGetReadyMessage() +} diff --git a/pkg/appruntime/retriever/multiqueryretriever.go b/pkg/appruntime/retriever/multiqueryretriever.go index 610ec9c1f..0d5cc4ac1 100644 --- a/pkg/appruntime/retriever/multiqueryretriever.go +++ b/pkg/appruntime/retriever/multiqueryretriever.go @@ -69,13 +69,9 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m return args, errors.New("empty question") } - v1, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v1.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err } v2, ok := args[base.LangchaingoLLMKeyInArg] @@ -88,7 +84,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m } prompt := prompts.NewPromptTemplate(_defaultQueryTemplate, []string{"question"}) llmchain := chains.NewLLMChain(llm, prompt, chains.WithCallback(log.KLogHandler{LogLevel: 3})) - multiqueryRetriever := retrievers.NewMultiQueryRetriever(retriever, llmchain, true) + multiqueryRetriever := retrievers.NewMultiQueryRetriever(retrieversInArg[0], llmchain, true) multiqueryRetriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} docs, err := multiqueryRetriever.GetRelevantDocuments(ctx, query) if err != nil { @@ -107,7 +103,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m newDocs, newRef := ConvertDocuments(ctx, newDocs, "multiquery") // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"} return args, nil } diff --git a/pkg/appruntime/retriever/rerankretriever.go b/pkg/appruntime/retriever/rerankretriever.go index 4ae0c1fe6..0cbbd3a64 100644 --- a/pkg/appruntime/retriever/rerankretriever.go +++ b/pkg/appruntime/retriever/rerankretriever.go @@ -73,7 +73,7 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s return nil, &base.RetrieverGetNullDocError{Msg: docNullReturn} } } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: nil, Name: "RerankRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = &Fakeretriever{Docs: nil, Name: "RerankRetriever"} return args, nil } q, ok := args[base.InputQuestionKeyInArg] @@ -143,15 +143,11 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err } - docs, err := retriever.GetRelevantDocuments(ctx, query) + docs, err := retrievers[0].GetRelevantDocuments(ctx, query) if err != nil { return args, fmt.Errorf("get relevant documents failed: %w", err) } @@ -163,7 +159,7 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s } } } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "RerankRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "RerankRetriever"} return args, nil }