Skip to content

Commit

Permalink
Merge pull request #979 from bjwswang/main
Browse files Browse the repository at this point in the history
Couple fixes for RAG evaluation
  • Loading branch information
bjwswang committed Apr 3, 2024
2 parents 4c596d9 + 616f9b6 commit 6b7d139
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 36 deletions.
2 changes: 2 additions & 0 deletions api/base/v1alpha1/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func (worker Worker) BuildEmbedder() *Embedder {
Type: embeddings.OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
APIGroup: pointer.String(GroupVersion.String()),
Kind: "Worker",
Namespace: &worker.Namespace,
Name: worker.Name,
Expand All @@ -188,6 +189,7 @@ func (worker Worker) BuildLLM() *LLM {
Type: llms.OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
APIGroup: pointer.String(GroupVersion.String()),
Kind: "Worker",
Namespace: &worker.Namespace,
Name: worker.Name,
Expand Down
85 changes: 49 additions & 36 deletions apiserver/pkg/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat
}

func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfigInput, hasMultiQueryRetriever, hasRerankRetriever bool) error {
app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile)
app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile)
app.Spec.Prologue = pointer.StringDeref(input.Prologue, app.Spec.Prologue)
app.Spec.ShowRespInfo = pointer.BoolDeref(input.ShowRespInfo, app.Spec.ShowRespInfo)
app.Spec.ShowRetrievalInfo = pointer.BoolDeref(input.ShowRetrievalInfo, app.Spec.ShowRetrievalInfo)
Expand All @@ -755,16 +755,18 @@ func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfi
return nil
}

func redefineNodes(knowledgebase *string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) {
// redefineNodes redefine nodes in application
func redefineNodes(knowledgebase *string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) {
nodes = []v1alpha1.Node{
{
NodeConfig: v1alpha1.NodeConfig{
Name: "Input",
DisplayName: "用户输入",
Description: "用户输入节点,必须",
Ref: &v1alpha1.TypedObjectReference{
Kind: "Input",
Name: "Input",
Kind: "Input",
Name: "Input",
Namespace: &namespace,
},
},
NextNodeName: []string{"prompt-node"},
Expand All @@ -775,9 +777,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "prompt",
Description: "设定prompt,template中可以使用{{.}}来替换变量",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("prompt.arcadia.kubeagi.k8s.com.cn"),
Kind: "Prompt",
Name: name,
APIGroup: pointer.String("prompt.arcadia.kubeagi.k8s.com.cn"),
Kind: "Prompt",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -790,9 +793,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "documentloader",
Description: "文档加载,可选",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "DocumentLoader",
Name: name,
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "DocumentLoader",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -804,9 +808,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "llm",
Description: "设定大模型的访问信息",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "LLM",
Name: llmName,
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "LLM",
Name: llmName,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -821,9 +826,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "llm chain",
Description: "chain是langchain的核心概念,llmChain用于连接prompt和llm",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "LLMChain",
Name: name,
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "LLMChain",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"Output"},
Expand All @@ -836,9 +842,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "知识库",
Description: "连接知识库",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBase",
Name: pointer.StringDeref(knowledgebase, ""),
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBase",
Name: pointer.StringDeref(knowledgebase, ""),
Namespace: &namespace,
},
},
NextNodeName: []string{"retriever-node"},
Expand All @@ -849,9 +856,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "从知识库提取信息的retriever",
Description: "连接应用和知识库",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBaseRetriever",
Name: name,
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBaseRetriever",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -878,9 +886,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "多查询retriever",
Description: "多查询retriever",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "MultiQueryRetriever",
Name: name,
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "MultiQueryRetriever",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{nextNodeName},
Expand All @@ -894,9 +903,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "rerank retriever",
Description: "rerank retriever",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "RerankRetriever",
Name: name,
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "RerankRetriever",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -909,9 +919,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "RetrievalQA chain",
Description: "chain是langchain的核心概念RetrievalQAChain用于从retriever中提取信息,供llm调用",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "RetrievalQAChain",
Name: name,
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "RetrievalQAChain",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"Output"},
Expand All @@ -924,9 +935,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "agent",
Description: "agent 调用复杂工具完成任务",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "Agent",
Name: name,
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "Agent",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -938,8 +950,9 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "最终输出",
Description: "最终输出节点,必须",
Ref: &v1alpha1.TypedObjectReference{
Kind: "Output",
Name: "Output",
Kind: "Output",
Name: "Output",
Namespace: &namespace,
},
},
})
Expand Down
4 changes: 4 additions & 0 deletions apiserver/pkg/knowledgebase/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func knowledgebase2model(ctx context.Context, c client.Client, knowledgebase *v1
}

source := &generated.TypedObjectReference{
APIGroup: fg.Source.APIGroup,
Kind: fg.Source.Kind,
Name: fg.Source.Name,
Namespace: new(string),
Expand Down Expand Up @@ -122,6 +123,7 @@ func knowledgebase2model(ctx context.Context, c client.Client, knowledgebase *v1

embedderResource := &v1alpha1.Embedder{}
embedder := generated.TypedObjectReference{
APIGroup: knowledgebase.Spec.Embedder.APIGroup,
Kind: knowledgebase.Spec.Embedder.Kind,
Name: knowledgebase.Spec.Embedder.Name,
Namespace: knowledgebase.Spec.Embedder.Namespace,
Expand Down Expand Up @@ -156,6 +158,7 @@ func knowledgebase2model(ctx context.Context, c client.Client, knowledgebase *v1
EmbedderType: &embedderType,
// Vector info
VectorStore: &generated.TypedObjectReference{
APIGroup: knowledgebase.Spec.VectorStore.APIGroup,
Kind: knowledgebase.Spec.VectorStore.Kind,
Name: knowledgebase.Spec.VectorStore.Name,
Namespace: knowledgebase.Spec.VectorStore.Namespace,
Expand Down Expand Up @@ -246,6 +249,7 @@ func CreateKnowledgeBase(ctx context.Context, c client.Client, input generated.C
Description: description,
},
Embedder: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Kind: "Embedder",
Name: embedder,
Namespace: &input.Namespace,
Expand Down
3 changes: 3 additions & 0 deletions apiserver/pkg/versioneddataset/versioned_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ func UpdateVersionedDataset(ctx context.Context, c client.Client, input *generat
for _, item := range input.FileGroups {
tmp := v1alpha1.FileGroup{
Source: &v1alpha1.TypedObjectReference{
APIGroup: item.Source.APIGroup,
Kind: item.Source.Kind,
Name: item.Source.Name,
Namespace: item.Source.Namespace,
Expand Down Expand Up @@ -312,6 +313,7 @@ func CreateVersionedDataset(ctx context.Context, c client.Client, input *generat
vds.Spec = v1alpha1.VersionedDatasetSpec{
Version: input.Version,
Dataset: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Kind: "Dataset",
Name: input.DatasetName,
Namespace: &input.Namespace,
Expand All @@ -328,6 +330,7 @@ func CreateVersionedDataset(ctx context.Context, c client.Client, input *generat
for _, item := range input.FileGrups {
tmp := v1alpha1.FileGroup{
Source: &v1alpha1.TypedObjectReference{
APIGroup: item.Source.APIGroup,
Kind: item.Source.Kind,
Name: item.Source.Name,
Namespace: item.Source.Namespace,
Expand Down
1 change: 1 addition & 0 deletions apiserver/pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ func CreateWorker(ctx context.Context, c client.Client, input generated.CreateWo
},
Type: workerType,
Model: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Name: input.Model.Name,
Namespace: &modelNs,
Kind: "Model",
Expand Down
5 changes: 5 additions & 0 deletions deploy/charts/arcadia/templates/rag-rbac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ rules:
- knowledgebases
- embedders
- vectorstores
- documentloaders
- agents
verbs:
- get
- list
Expand All @@ -36,6 +38,7 @@ rules:
resources:
- llmchains
- retrievalqachains
- apichains
verbs:
- get
- list
Expand All @@ -50,6 +53,8 @@ rules:
- retriever.arcadia.kubeagi.k8s.com.cn
resources:
- knowledgebaseretrievers
- multiqueryretrievers
- rerankretrievers
verbs:
- list
- get
Expand Down
5 changes: 5 additions & 0 deletions pkg/versioneddataset/versioneddataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
"k8s.io/utils/pointer"

"github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/datasource"
Expand Down Expand Up @@ -61,6 +62,7 @@ func generateInheritedFileStatus(oss *datasource.OSS, instance *v1alpha1.Version
return []v1alpha1.FileStatus{
{
TypedObjectReference: v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Name: name,
Namespace: &instance.Namespace,
Kind: "VersionedDataset",
Expand Down Expand Up @@ -93,6 +95,7 @@ func generateDatasourceFileStatus(instance *v1alpha1.VersionedDataset) []v1alpha
_, _ = fmt.Sscanf(datasource, "%s %s", &namespace, &name)
item := v1alpha1.FileStatus{
TypedObjectReference: v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Name: name,
Namespace: &namespace,
Kind: "Datasource",
Expand Down Expand Up @@ -168,6 +171,8 @@ func CopiedFileGroup2Status(oss *datasource.OSS, instance *v1alpha1.VersionedDat
if len(datasourceFiles) > 0 {
ds := v1alpha1.FileStatus{
TypedObjectReference: v1alpha1.TypedObjectReference{
APIGroup: item.APIGroup,
Kind: item.Kind,
Name: item.Name,
Namespace: item.Namespace,
},
Expand Down

0 comments on commit 6b7d139

Please sign in to comment.