Skip to content

Commit

Permalink
feat: add gemini llm, use it in tests
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <fp544037857@gmail.com>
  • Loading branch information
Abirdcfly committed Feb 21, 2024
1 parent 11945d7 commit 714bfb8
Show file tree
Hide file tree
Showing 21 changed files with 163 additions and 9 deletions.
2 changes: 2 additions & 0 deletions api/base/v1alpha1/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ func (e Embedder) Get3rdPartyModels() []string {
return embeddings.ZhiPuAIModels
case embeddings.OpenAI:
return embeddings.OpenAIModels
case embeddings.Gemini:
return embeddings.GeminiModels
}

return []string{}
Expand Down
2 changes: 2 additions & 0 deletions api/base/v1alpha1/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func (llm LLM) Get3rdPartyModels() []string {
return llms.ZhiPuAIModels
case llms.OpenAI:
return llms.OpenAIModels
case llms.Gemini:
return llms.GeminiModels
}
return []string{}
}
Expand Down
1 change: 1 addition & 0 deletions config/samples/app_llmchain_englishteacher.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,4 @@ spec:
description: "llm chain"
memory:
maxTokenLimit: 20480
model: chatglm_turbo
1 change: 1 addition & 0 deletions config/samples/app_retrievalqachain_knowledgebase.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ spec:
description: "用于搜索QA的Chain"
memory:
maxTokenLimit: 20480
model: chatglm_turbo
---
apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: KnowledgeBaseRetriever
Expand Down
22 changes: 22 additions & 0 deletions config/samples/app_shared_llm_service_gemini.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Sample manifests for a shared LLM service backed by Google Gemini:
# a Secret holding the API key, followed by the LLM resource that uses it.
apiVersion: v1
kind: Secret
metadata:
name: app-shared-llm-secret
namespace: arcadia
type: Opaque
data:
# NOTE(review): values under a Secret's `data` are base64-encoded, so this
# appears to be a usable credential checked into a sample file. Confirm it is
# a throwaway key, or inject it from CI secrets instead of committing it.
apiKey: "QUl6YVN5QVZOdGRYOHpkeU5pNWpubzNYSExUWGM0UnpJSGxIRUFz"
---
# LLM resource of type "gemini" pointing at the Google Generative Language API.
apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: LLM
metadata:
name: app-shared-llm-service
namespace: arcadia
spec:
type: "gemini"
provider:
endpoint:
url: "https://generativelanguage.googleapis.com/"
authSecret:
# References the Secret declared at the top of this file by name.
kind: secret
name: app-shared-llm-secret
File renamed without changes.
22 changes: 22 additions & 0 deletions config/samples/arcadia_v1alpha1_embedders_gemini.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Sample manifests for a Gemini-backed Embedder: a Secret holding the API key,
# followed by the Embedder resource that uses it.
apiVersion: v1
kind: Secret
metadata:
name: gemini
namespace: arcadia
type: Opaque
data:
# NOTE(review): values under a Secret's `data` are base64-encoded, so this
# appears to be a usable credential checked into a sample file. Confirm it is
# a throwaway key, or inject it from CI secrets instead of committing it.
apiKey: "QUl6YVN5QVZOdGRYOHpkeU5pNWpubzNYSExUWGM0UnpJSGxIRUFz"
---
# Embedder resource of type "gemini" pointing at the Google Generative
# Language API.
apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: Embedder
metadata:
name: embedders-sample
namespace: arcadia
spec:
type: "gemini"
provider:
endpoint:
url: "https://generativelanguage.googleapis.com/"
authSecret:
# References the "gemini" Secret declared at the top of this file.
kind: secret
name: gemini
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ data:
apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: Embedder
metadata:
name: zhipuai-embedders-sample
name: embedders-sample
namespace: arcadia
spec:
type: "zhipuai"
Expand Down
2 changes: 1 addition & 1 deletion config/samples/arcadia_v1alpha1_knowledgebase.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ spec:
description: "测试 KnowledgeBase"
embedder:
kind: Embedders
name: zhipuai-embedders-sample
name: embedders-sample
namespace: arcadia
vectorStore:
kind: VectorStores
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ spec:
description: "测试 KnowledgeBase"
embedder:
kind: Embedders
name: zhipuai-embedders-sample
name: embedders-sample
namespace: arcadia
vectorStore:
kind: VectorStores
Expand Down
22 changes: 22 additions & 0 deletions controllers/base/embedder_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/go-logr/logr"
langchainembeddings "github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/googleai"
langchainopenai "github.com/tmc/langchaingo/llms/openai"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
Expand Down Expand Up @@ -227,6 +228,27 @@ func (r *EmbedderReconciler) check3rdPartyEmbedder(ctx context.Context, logger l
}
msg = "Success"
}
case embeddings.Gemini:
// validate all embedding models
for _, model := range models {
llm, err := googleai.New(
ctx,
googleai.WithAPIKey(apiKey),
googleai.WithDefaultEmbeddingModel(model),
)
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
embedClient, err := langchainembeddings.NewEmbedder(llm)
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
_, err = embedClient.EmbedQuery(ctx, embedingText)
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
msg = "Success"
}
default:
return r.UpdateStatus(ctx, instance, nil, fmt.Errorf("unsupported service type: %s", instance.Spec.Type))
}
Expand Down
14 changes: 14 additions & 0 deletions controllers/base/llm_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/go-logr/logr"
langchainllms "github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/googleai"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
Expand Down Expand Up @@ -215,6 +216,19 @@ func (r *LLMReconciler) check3rdPartyLLM(ctx context.Context, logger logr.Logger
}
msg = strings.Join([]string{msg, res.String()}, "\n")
}
case llms.Gemini:
llmClient, err := googleai.New(ctx, googleai.WithAPIKey(apiKey))
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
// validate against models
for _, model := range models {
res, err := llmClient.Call(ctx, "Hello", langchainllms.WithModel(model))
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
msg = strings.Join([]string{msg, res}, "\n")
}
default:
return r.UpdateStatus(ctx, instance, nil, fmt.Errorf("unsupported service type: %s", instance.Spec.Type))
}
Expand Down
3 changes: 3 additions & 0 deletions controllers/base/prompt_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package controllers

import (
"context"
"errors"
"fmt"
"reflect"

Expand Down Expand Up @@ -139,6 +140,8 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom
if err != nil {
return r.UpdateStatus(ctx, prompt, nil, err)
}
case llms.Gemini:
return r.UpdateStatus(ctx, prompt, nil, errors.New("not implemented yet"))
default:
llmClient = llms.NewUnknowLLM()
}
Expand Down
14 changes: 13 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ require (
)

require (
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.2.0 // indirect
Expand All @@ -53,6 +55,10 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.16.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/generative-ai-go v0.5.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.3 // indirect
Expand Down Expand Up @@ -81,10 +87,16 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254 // indirect
golang.org/x/arch v0.6.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/tools v0.16.1 // indirect
google.golang.org/api v0.152.0 // indirect
google.golang.org/genproto v0.0.0-20231120223509-83a465c0220f // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20231211222908-989df2bf70f3 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231211222908-989df2bf70f3 // indirect
google.golang.org/grpc v1.60.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

Expand All @@ -102,7 +114,7 @@ require (
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/emicklei/go-restful v2.9.5+incompatible // indirect
Expand Down
3 changes: 2 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6
github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0=
Expand Down
2 changes: 2 additions & 0 deletions pkg/embeddings/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ type EmbeddingType string
// Supported embedding service types. The string values are what a resource
// spec carries in its type field (e.g. type: "gemini").
const (
	// Gemini identifies the Google Gemini embedding service.
	Gemini EmbeddingType = "gemini"
	// OpenAI identifies the OpenAI embedding service.
	OpenAI EmbeddingType = "openai"
	// ZhiPuAI identifies the ZhiPuAI embedding service.
	ZhiPuAI EmbeddingType = "zhipuai"

	// Unknown is the placeholder for an unrecognized embedding service type.
	Unknown EmbeddingType = "unknown"
)

// Built-in embedding model names for each supported third-party provider.
var (
	// GeminiModels lists the embedding models known for the Gemini provider.
	GeminiModels = []string{"embedding-001"}

	// OpenAIModels lists the embedding models known for the OpenAI provider.
	OpenAIModels = []string{"text-embedding-ada-002"}

	// ZhiPuAIModels lists the embedding models known for the ZhiPuAI provider.
	ZhiPuAIModels = []string{"text_embedding"}
)
2 changes: 2 additions & 0 deletions pkg/evaluation/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ func JudgeJobGenerator(ctx context.Context, c client.Client) func(*evav1alpha1.R
model = "gtp4"
case llms.ZhiPuAI:
model = "glm-4"
case llms.Gemini:
model = "gemini-pro"
default:
return nil, fmt.Errorf("not support type %s", llm.Spec.Type)
}
Expand Down
20 changes: 20 additions & 0 deletions pkg/langchainwrap/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"

langchaingoembeddings "github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/googleai"
"github.com/tmc/langchaingo/llms/openai"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -68,6 +69,25 @@ func GetLangchainEmbedder(ctx context.Context, e *v1alpha1.Embedder, c client.Cl
return nil, err
}
return langchaingoembeddings.NewEmbedder(llm, opts...)
case embeddings.Gemini:
apiKey, err := e.AuthAPIKey(ctx, c)
if err != nil {
return nil, err
}

if model == "" {
models := e.GetModelList()
if len(models) == 0 {
return nil, errors.New("no valid models for this Embedder")
}
model = models[0]
}

llm, err := googleai.New(ctx, googleai.WithAPIKey(apiKey), googleai.WithDefaultEmbeddingModel(model))
if err != nil {
return nil, err
}
return langchaingoembeddings.NewEmbedder(llm, opts...)
}
case v1alpha1.ProviderTypeWorker:
gateway, err := config.GetGateway(ctx, c)
Expand Down
10 changes: 10 additions & 0 deletions pkg/langchainwrap/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"os"

langchainllms "github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/googleai"
"github.com/tmc/langchaingo/llms/openai"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -60,6 +61,15 @@ func GetLangchainLLM(ctx context.Context, llm *v1alpha1.LLM, c client.Client, mo
model = models[0]
}
return openai.New(openai.WithToken(apiKey), openai.WithBaseURL(llm.Get3rdPartyLLMBaseURL()), openai.WithModel(model))
case llms.Gemini:
if model == "" {
models := llm.GetModelList()
if len(models) == 0 {
return nil, errors.New("no valid models for this LLM")
}
model = models[0]
}
return googleai.New(ctx, googleai.WithAPIKey(apiKey), googleai.WithDefaultModel(model))
}
case v1alpha1.ProviderTypeWorker:
gateway, err := config.GetGateway(ctx, c)
Expand Down
6 changes: 5 additions & 1 deletion pkg/llms/llms.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ const (
OpenAI LLMType = "openai"
ZhiPuAI LLMType = "zhipuai"
DashScope LLMType = "dashscope"
Gemini LLMType = "gemini"
Unknown LLMType = "unknown"
)

var OpenAIModels = []string{"gpt-3.5", "gpt-3.5-turbo"}
// Built-in chat/completion model names for third-party LLM providers.
var (
	// GeminiModels lists the models known for the Gemini provider.
	GeminiModels = []string{"gemini-pro"}

	// OpenAIModels lists the models known for the OpenAI provider.
	OpenAIModels = []string{"gpt-3.5", "gpt-3.5-turbo"}
)

var (
ZhiPuAILite string = "chatglm_lite"
Expand Down
20 changes: 17 additions & 3 deletions tests/example-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,14 @@ kubectl apply -f config/samples/arcadia_v1alpha1_versioneddataset.yaml
waitCRDStatusReady "VersionedDataset" "arcadia" "dataset-playground-v1"

info "7.3 create embedder and wait it ready"
kubectl apply -f config/samples/arcadia_v1alpha1_embedders.yaml
waitCRDStatusReady "Embedders" "arcadia" "zhipuai-embedders-sample"
# Choose which embedder sample to apply: the Gemini one when running on
# GitHub Actions (GITHUB_ACTIONS is exactly "true"), the ZhiPuAI one otherwise.
case "$GITHUB_ACTIONS" in
true)
	info "in github action, use gemini"
	kubectl apply -f config/samples/arcadia_v1alpha1_embedders_gemini.yaml
	;;
*)
	info "in local, use zhipu"
	kubectl apply -f config/samples/arcadia_v1alpha1_embedders_zhipu.yaml
	;;
esac
waitCRDStatusReady "Embedders" "arcadia" "embedders-sample"

info "7.4 create knowledgebase and wait it ready"
info "7.4.1 create knowledgebase based on chroma and wait it ready"
Expand Down Expand Up @@ -348,7 +354,15 @@ fi

info "8 validate simple app can work normally"
info "Prepare dependent LLM service"
kubectl apply -f config/samples/app_shared_llm_service.yaml
# Choose which shared LLM service sample to apply. On GitHub Actions
# (GITHUB_ACTIONS is exactly "true") the Gemini service is used; anywhere else
# the ZhiPuAI service is used.
case "$GITHUB_ACTIONS" in
true)
	info "in github action, use gemini"
	# Rewrite the sample manifests in place so every ZhiPuAI model name
	# is replaced with gemini-pro before the samples are applied.
	sed -i 's/model: chatglm_turbo/model: gemini-pro/g' config/samples/*
	sed -i 's/model: glm-4/model: gemini-pro/g' config/samples/*
	kubectl apply -f config/samples/app_shared_llm_service_gemini.yaml
	;;
*)
	info "in local, use zhipu"
	kubectl apply -f config/samples/app_shared_llm_service_zhipu.yaml
	;;
esac

info "8.1 app of llmchain"
kubectl apply -f config/samples/app_llmchain_englishteacher.yaml
Expand Down

0 comments on commit 714bfb8

Please sign in to comment.