
Commit

feat: add a new mapreducechain in appruntime
Signed-off-by: bjwswang <bjwswang@gmail.com>
bjwswang committed Feb 23, 2024
1 parent 924a43f commit 0d4cb7e
Showing 5 changed files with 220 additions and 66 deletions.
71 changes: 30 additions & 41 deletions apiserver/pkg/chat/chat_docs.go
@@ -27,10 +27,8 @@ import (
"sync"
"time"

"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/documentloaders"
langchainllms "github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/textsplitter"
"k8s.io/apimachinery/pkg/util/uuid"
@@ -40,7 +38,7 @@ import (
"github.com/kubeagi/arcadia/apiserver/pkg/chat/storage"
"github.com/kubeagi/arcadia/apiserver/pkg/common"
runtimebase "github.com/kubeagi/arcadia/pkg/appruntime/base"
runtimechains "github.com/kubeagi/arcadia/pkg/appruntime/chain"
runtimechain "github.com/kubeagi/arcadia/pkg/appruntime/chain"
runtimellm "github.com/kubeagi/arcadia/pkg/appruntime/llm"
"github.com/kubeagi/arcadia/pkg/langchainwrap"
"github.com/kubeagi/arcadia/pkg/utils"
@@ -233,7 +231,7 @@ func (cs *ChatServer) SummarizeConversationDoc(ctx context.Context, req Conversa
<-semaphore
}()
klog.V(5).Infof("Generate summarization from file %s for conversation %s", doc.Filename, req.ConversationID)
summary, errSummary = cs.GenerateSingleDocSummary(ctx, req, doc, documents, respStream)
summary, errSummary = cs.GenerateSingleDocSummary(ctx, req, documents, respStream)
if errSummary != nil {
// break once error occurs
errStr += fmt.Sprintf(" ErrSummary: %s", errSummary.Error())
@@ -278,14 +276,14 @@ func (cs *ChatServer) GenerateSingleDocEmbeddings(ctx context.Context, req Conve
}

// GenerateSingleDocSummary generates the summary of a single document
func (cs *ChatServer) GenerateSingleDocSummary(ctx context.Context, req ConversationDocsReqBody, doc *multipart.FileHeader, documents []schema.Document, respStream chan string) (string, error) {
func (cs *ChatServer) GenerateSingleDocSummary(ctx context.Context, req ConversationDocsReqBody, documents []schema.Document, respStream chan string) (string, error) {
app, c, err := cs.getApp(ctx, req.APPName, req.AppNamespace)
if err != nil {
return "", fmt.Errorf("failed to get app due to %s", err.Error())
}

var llm langchainllms.LLM
chainCallOptions := make([]chains.ChainCallOption, 0)
var mpChainNode runtimebase.BaseNode
// find LLM along with chain call options
for _, n := range app.Spec.Nodes {
baseNode := runtimebase.NewBaseNode(app.Namespace, n.Name, *n.Ref)
@@ -297,50 +295,41 @@ func (cs *ChatServer) GenerateSingleDocSummary(ctx context.Context, req Conversa
}
llm = l.LLM
case "llmchain":
llmchain := runtimechains.NewLLMChain(baseNode)
if err := llmchain.Init(ctx, cs.cli, nil); err != nil {
return "", err
}
chainCallOptions = runtimechains.GetChainOptions(llmchain.Instance.Spec.CommonChainConfig)
mpChainNode = baseNode
case "retrievalqachain":
retrivalQAChain := runtimechains.NewRetrievalQAChain(baseNode)
if err := retrivalQAChain.Init(ctx, cs.cli, nil); err != nil {
return "", err
}
chainCallOptions = runtimechains.GetChainOptions(retrivalQAChain.Instance.Spec.CommonChainConfig)
mpChainNode = baseNode
}
}

// If no LLM provided, we can't generate the summary
if llm == nil {
return "", ErrNoLLMProvidedInApplication
}

// initialize a MapReduceChain
mpChain := chains.NewMapReduceDocuments(
chains.NewLLMChain(llm, prompts.NewPromptTemplate(DefaultPromptTemplateForMap, []string{"context"})),
chains.NewStuffDocuments(
chains.NewLLMChain(
llm,
prompts.NewPromptTemplate(DefaultPromptTemplatForReduce, []string{"context"}),
),
),
)

// concurrent api call
mpChain.MaxNumberOfConcurrent = DefaultSummaryMaxNumberOfConcurrent

var summary string
if req.ResponseMode.IsStreaming() {
chainCallOptions = append(chainCallOptions, chains.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
respStream <- string(chunk)
return nil
}))
out := map[string]any{
"question": req.Query,
"_answer_stream": respStream,
"llm": llm,
"documents": documents,
}
if req.ResponseMode == "streaming" {
out["_need_stream"] = true
}
// initialize MapReduceChain
mpChain := runtimechain.NewMapReduceChain(mpChainNode)
if err = mpChain.Init(ctx, cs.cli, out); err != nil {
return "", err
}
summary, err = chains.Run(ctx, mpChain, documents, chainCallOptions...)
out, err = mpChain.Run(ctx, cs.cli, out)
if err != nil {
return "", fmt.Errorf("failed to generate summary for %s due to %s", doc.Filename, err.Error())
return "", fmt.Errorf("failed to generate summary due to %s", err.Error())
}

return summary, nil
a, ok := out["_answer"]
if !ok {
return "", errors.New("empty answer")
}
answer, ok := a.(string)
if !ok {
return "", errors.New("invalid answer: not a string")
}
return answer, nil
}
2 changes: 1 addition & 1 deletion go.mod
@@ -181,4 +181,4 @@ require (
sigs.k8s.io/yaml v1.3.0 // indirect
)

replace github.com/tmc/langchaingo => github.com/Abirdcfly/langchaingo v0.0.0-20240124015404-c7798664fdb1 // branch arcadia
replace github.com/tmc/langchaingo => github.com/kubeagi/langchaingo v0.0.0-20240223090005-71cb3f753545 // branch arcadia
4 changes: 2 additions & 2 deletions go.sum
@@ -51,8 +51,6 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/99designs/gqlgen v0.17.40 h1:/l8JcEVQ93wqIfmH9VS1jsAkwm6eAF1NwQn3N+SDqBY=
github.com/99designs/gqlgen v0.17.40/go.mod h1:b62q1USk82GYIVjC60h02YguAZLqYZtvWml8KkhJps4=
github.com/Abirdcfly/langchaingo v0.0.0-20240124015404-c7798664fdb1 h1:CtCYyn/1c4Dqyk7WSMWUzOAjG1tZSj4JQpAnmhx7b8k=
github.com/Abirdcfly/langchaingo v0.0.0-20240124015404-c7798664fdb1/go.mod h1:vOFzX91wqTXvirejd6xjPXSmGn8yYKHt/FunAgrOBmI=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24=
@@ -484,6 +482,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kubeagi/langchaingo v0.0.0-20240223090005-71cb3f753545 h1:eCwU61MXxpWBDJUPBYeN0cTIj4Vy9QtNPn6r7LtKmV4=
github.com/kubeagi/langchaingo v0.0.0-20240223090005-71cb3f753545/go.mod h1:vOFzX91wqTXvirejd6xjPXSmGn8yYKHt/FunAgrOBmI=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
23 changes: 1 addition & 22 deletions pkg/appruntime/app_runtime.go
@@ -59,41 +59,20 @@ type Application struct {
EndingNode base.Node
}

// var cache = map[string]*Application{}

// func cacheKey(app *arcadiav1alpha1.Application) string {
// return app.Namespace + "/" + app.Name
//}

func NewAppOrGetFromCache(ctx context.Context, cli client.Client, app *arcadiav1alpha1.Application) (*Application, error) {
if app == nil || app.Name == "" || app.Namespace == "" {
return nil, errors.New("app has no name or namespace")
}
// TODO: disable cache for now.
// https://github.com/kubeagi/arcadia/issues/391
// a, ok := cache[cacheKey(app)]
// if !ok {
// a = &Application{
// Spec: app.Spec,
// }
// cache[cacheKey(app)] = a
// return a, a.Init(ctx, cli)
// }
// if reflect.DeepEqual(a.Spec, app.Spec) {
// return a, nil
// }
a := &Application{
Namespace: app.GetNamespace(),
Name: app.Name,
Spec: app.Spec,
Inited: false,
}
// a.Spec = app.Spec
// a.Inited = false
return a, a.Init(ctx, cli)
}

// todo prevent infinite loops: need to check whether the node graph contains a cycle
// TODO: prevent infinite loops: need to check whether the node graph contains a cycle
func (a *Application) Init(ctx context.Context, cli client.Client) (err error) {
if a.Inited {
return
186 changes: 186 additions & 0 deletions pkg/appruntime/chain/mpchain.go
@@ -0,0 +1,186 @@
/*
Copyright 2024 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package chain

import (
"context"
"errors"
"fmt"

"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeagi/arcadia/pkg/appruntime/base"
)

const (
// For map-reduce
DefaultPromptTemplateForMap = `
{{.context}}
With above content, please summarize it with only half content size of it.
`
DefaultPromptTemplatForReduce = `"{{.context}}"`

// For post-processing the map-reduced summary
DefaultPromptTemplateForPostMapReduce = `
Here is the map-reduced summary of a document:
Summary: {{.summary}}
Now please answer the following question based on the above document summary. Make sure the answer is using same language with the question:
Question: {{.question}}
Answer:
`

DefaultSummaryMaxNumberOfConcurrent = 2
DefaultDocumentChunkSize = 1024
DefaultDocumentChunkOverlap = 100
)
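// Note (editorial, not part of this commit): these three templates compose as
// follows inside MapReduceChain. Each document chunk is rendered into
// DefaultPromptTemplateForMap and summarized by the LLM; the per-chunk
// summaries are then stuffed into DefaultPromptTemplatForReduce to produce one
// combined summary; finally DefaultPromptTemplateForPostMapReduce is rendered
// with that combined summary and the user's question, and the LLM's reply is
// returned under the "_answer" key.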

type MapReduceChain struct {
// BaseNode for this MapReduceChain
// Only chain is allowed
base.BaseNode

// isReady indicates whether this chain is ready to use
isReady bool
// message indicates the detailed message of ready/not ready
message string

// MapReduceDocuments used to generate summary
chains.MapReduceDocuments
// LLMChain used to post-process the map-reduced summary with the user's question
chains.LLMChain

// call options against llm
chainCallOptions []chains.ChainCallOption
}

func NewMapReduceChain(baseNode base.BaseNode) *MapReduceChain {
return &MapReduceChain{
BaseNode: baseNode,
MapReduceDocuments: chains.MapReduceDocuments{},
}
}

func (l *MapReduceChain) Init(ctx context.Context, cli client.Client, args map[string]any) error {
if args == nil {
return errors.New("no arguments provided for MapReduceChain")
}
// initialize the LLM
v1, ok := args["llm"]
if !ok {
return errors.New("no llm")
}
llm, ok := v1.(llms.LLM)
if !ok {
return errors.New("llm not llms.LLM")
}

// only group `chain` is allowed
if l.BaseNode.Group() != "chain" {
return fmt.Errorf("invalid base node with group %s.must be in group chain", l.BaseNode.Group())
}
// initialize call options
var chainCallOptions []chains.ChainCallOption
switch kind := l.BaseNode.Kind(); kind {
case "llmchain":
llmchain := NewLLMChain(l.BaseNode)
if err := llmchain.Init(ctx, cli, nil); err != nil {
return err
}
l.isReady, l.message = llmchain.Ready()
if !l.isReady {
return fmt.Errorf("llmchain is not ready with %s", l.message)
}
chainCallOptions = GetChainOptions(llmchain.Instance.Spec.CommonChainConfig)
case "retrievalqachain":
retrivalQAChain := NewRetrievalQAChain(l.BaseNode)
if err := retrivalQAChain.Init(ctx, cli, nil); err != nil {
return err
}
l.isReady, l.message = retrivalQAChain.Ready()
if !l.isReady {
return fmt.Errorf("retrivalQAChain is not ready with %s", l.message)
}
chainCallOptions = GetChainOptions(retrivalQAChain.Instance.Spec.CommonChainConfig)
default:
return fmt.Errorf("invalid base node kind %s for MapReduceChain.not supported yet", kind)
}
l.chainCallOptions = append(l.chainCallOptions, chainCallOptions...)

// initialize MapReduceDocuments
l.MapReduceDocuments = chains.NewMapReduceDocuments(
chains.NewLLMChain(llm, prompts.NewPromptTemplate(DefaultPromptTemplateForMap, []string{"context"})),
chains.NewStuffDocuments(
chains.NewLLMChain(
llm,
prompts.NewPromptTemplate(DefaultPromptTemplatForReduce, []string{"context"}),
),
),
)

l.LLMChain = *chains.NewLLMChain(llm, prompts.NewPromptTemplate(DefaultPromptTemplateForPostMapReduce, []string{"summary", "question"}))

return nil
}

func (l *MapReduceChain) Run(ctx context.Context, cli client.Client, args map[string]any) (outArgs map[string]any, err error) {
v1, ok := args["documents"]
if !ok {
return args, errors.New("no documents")
}
documents, ok := v1.([]schema.Document)
if !ok {
return args, errors.New("llm not llms.LanguageModel")
}
// run MapReduceDocuments
out, err := chains.Run(ctx, l.MapReduceDocuments, documents, l.chainCallOptions...)
if err != nil {
return args, fmt.Errorf("failed to run MapReduceChain due to %s", err.Error())
}
// set the summary with the output of MapReduceDocuments
args["summary"] = out

// run LLMChain
needStream := false
needStream, ok = args["_need_stream"].(bool)
if ok && needStream {
l.chainCallOptions = append(l.chainCallOptions, chains.WithStreamingFunc(stream(args)))
}
// call llmchain
out, err = chains.Predict(ctx, l.LLMChain, args, l.chainCallOptions...)
// handler out & error
out, err = handleNoErrNoOut(ctx, needStream, out, err, l.LLMChain, args, l.chainCallOptions)
klog.FromContext(ctx).V(5).Info("use MapReduceChain, blocking out:" + out)
if err == nil {
args["_answer"] = out
return args, nil
}
return args, fmt.Errorf("mapreaducechain run error: %w", err)
}

func (l *MapReduceChain) Ready() (bool, string) {
return l.isReady, l.message
}
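
Below is a minimal usage sketch of the new MapReduceChain node, pieced together from the Init/Run contract above. The example package, the summarize helper, and the way the BaseNode is obtained are illustrative assumptions rather than part of this commit; in the apiserver the node is built from the application's llmchain or retrievalqachain node, as shown in GenerateSingleDocSummary.

package example

import (
	"context"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/schema"
	"sigs.k8s.io/controller-runtime/pkg/client"

	"github.com/kubeagi/arcadia/pkg/appruntime/base"
	"github.com/kubeagi/arcadia/pkg/appruntime/chain"
)

// summarize is an illustrative sketch (not part of the commit) showing how a
// caller drives MapReduceChain: Init resolves the referenced llmchain or
// retrievalqachain and builds the map-reduce pipeline, Run summarizes the
// documents, answers the question, and places the result under "_answer".
func summarize(ctx context.Context, cli client.Client, llm llms.LLM,
	node base.BaseNode, docs []schema.Document) (string, error) {
	args := map[string]any{
		"question":  "What is this document about?", // consumed by the post map-reduce prompt
		"llm":       llm,                             // LLM resolved from the application's llm node
		"documents": docs,                            // chunks produced by the text splitter
		// "_need_stream": true,                      // set this to stream chunks via "_answer_stream"
	}
	mpChain := chain.NewMapReduceChain(node)
	if err := mpChain.Init(ctx, cli, args); err != nil {
		return "", err
	}
	out, err := mpChain.Run(ctx, cli, args)
	if err != nil {
		return "", err
	}
	answer, _ := out["_answer"].(string)
	return answer, nil
}

In GenerateSingleDocSummary above, the same args map additionally carries "_answer_stream" (the response channel) and "_need_stream" when the request asks for a streaming response.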
