Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【wip】feat: upgrade zhipuai from v3 to v4 #580

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/zhipuai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func main() {
func sampleInvoke(apiKey string) (*zhipuai.Response, error) {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.Prompt = []zhipuai.Prompt{
params.Messages = []zhipuai.Message{
{Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."},
}
return client.Invoke(params)
Expand All @@ -78,7 +78,7 @@ func sampleInvoke(apiKey string) (*zhipuai.Response, error) {
func sampleInvokeAsync(apiKey string) (*zhipuai.Response, error) {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.Prompt = []zhipuai.Prompt{
params.Messages = []zhipuai.Message{
{Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."},
}
return client.AsyncInvoke(params)
Expand All @@ -94,7 +94,7 @@ func getInvokeAsyncResult(apiKey string, taskID string) (*zhipuai.Response, erro
func sampleSSEInvoke(apiKey string) error {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.Prompt = []zhipuai.Prompt{
params.Messages = []zhipuai.Message{
{Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."},
}
// you can define a customized `handler` on `Event`
Expand Down
10 changes: 6 additions & 4 deletions pkg/llms/zhipuai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@
)

const (
ZhipuaiModelAPIURL = "https://open.bigmodel.cn/api/paas/v3/model-api"
ZhipuaiModelDefaultTimeout = 30 * time.Second
RetryLimit = 3
ZhipuaiModelAPIURL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
ZhipuaiModelAPIAsyncURL = "https://open.bigmodel.cn/api/paas/v4/async/chat/completions"
ZhipuaiModelAPIAsyncGetResultURL = "https://open.bigmodel.cn/api/paas/v4/async-result/"
ZhipuaiModelDefaultTimeout = 300 * time.Second
RetryLimit = 3
)

type Method string
Expand All @@ -49,7 +51,7 @@
)

func BuildAPIURL(model string, method Method) string {
return fmt.Sprintf("%s/%s/%s", ZhipuaiModelAPIURL, model, method)
return fmt.Sprintf("%s/%s/%s", ZhipuaiModelAPISSEURL, model, method)

Check failure on line 54 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

undefined: ZhipuaiModelAPISSEURL
}

var _ llms.LLM = (*ZhiPuAI)(nil)
Expand All @@ -74,7 +76,7 @@
if err := params.Unmarshal(data); err != nil {
return nil, err
}
switch params.Method {

Check failure on line 79 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

params.Method undefined (type ModelParams has no field or method Method)
case ZhiPuAIInvoke:
return z.Invoke(params)
case ZhiPuAIAsyncInvoke:
Expand Down Expand Up @@ -110,12 +112,12 @@

// Get result of task async-invoke
func (z *ZhiPuAI) Get(params ModelParams) (*Response, error) {
if params.TaskID == "" {

Check failure on line 115 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

params.TaskID undefined (type ModelParams has no field or method TaskID)
return nil, errors.New("TaskID is required when running Get with method AsyncInvoke")
}

// url with task id
url := fmt.Sprintf("%s/%s", BuildAPIURL(params.Model, ZhiPuAIAsyncInvoke), params.TaskID)

Check failure on line 120 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

params.TaskID undefined (type ModelParams has no field or method TaskID)
token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
if err != nil {
return nil, err
Expand All @@ -141,7 +143,7 @@
return nil, err
}

testPrompt := []Prompt{

Check failure on line 146 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

undefined: Prompt
{
Role: "user",
Content: "Hello",
Expand All @@ -149,11 +151,11 @@
}

testParam := ModelParams{
Method: ZhiPuAIAsyncInvoke,

Check failure on line 154 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

unknown field Method in struct literal of type ModelParams
Model: llms.ZhiPuAILite,
Temperature: 0.95,
TopP: 0.7,
Prompt: testPrompt,

Check failure on line 158 in pkg/llms/zhipuai/api.go

View workflow job for this annotation

GitHub Actions / Build & cache Go code

unknown field Prompt in struct literal of type ModelParams
}

postResponse, err := Post(url, token, testParam, ZhipuaiModelDefaultTimeout)
Expand Down
105 changes: 75 additions & 30 deletions pkg/llms/zhipuai/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
*/

// NOTE: Reference zhipuai's python sdk: model_api/params.py
// NOTE: Reference zhipuai's python sdk: model_api/params.py and https://open.bigmodel.cn/dev/api#glm-4

package zhipuai

Expand All @@ -28,51 +28,99 @@
type Role string

const (
System Role = "system"
User Role = "user"
Assistant Role = "assistant"
Tool Role = "tool"
)

var _ llms.ModelParams = (*ModelParams)(nil)

// +kubebuilder:object:generate=true
// ZhiPuAIParams defines the params of ZhiPuAI Prompt Call
type ModelParams struct {
// Method used for this prompt call
Method Method `json:"method,omitempty"`
// The model to be called
Model string `json:"model"`
// Contents
Messages []Message `json:"messages"`
// Passed by the client, need to ensure the uniqueness;
// it isused to distinguish between the unique identity of each request,
// the platform will be generated by default when the client does not pass.
RequestID string `json:"request_id,omitempty"`
// Sampling strategy is enabled when do_sample is true,
// and sampling strategies temperature and top_p are not effective when do_sample is false. TODO
DoSample bool `json:"do_sample,omitempty"`
// Stream setting to true means use sse

Check failure on line 53 in pkg/llms/zhipuai/params.go

View workflow job for this annotation

GitHub Actions / misspellings

seting ==> setting
Stream bool `json:"stream,omitempty"`
// Temperature is float in zhipuai (0.0,1.0], default is 0.95
Temperature float32 `json:"temperature,omitempty"`
// TopP is float in zhipuai, (0.0, 1.0) default is 0.7
TopP float32 `json:"top_p,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
// The model will stop generating when it encounters a character formulated by stop,
// currently only a single stop word is supported, in the format of ["stop_word1"].
Stop []string `json:"stop,omitempty"`
// The list of tools available to LLM, the tools field counts the tokens and
// is also limited by the length of the tokens.
Tools []ToolReq `json:"tools,omitempty"`
// only auto now
ToolChoice string `json:"tool_choice,omitempty"`
}

// Model used for this prompt call
Model string `json:"model,omitempty"`
type ToolReq struct {
// function, retrieval, web_search
Type string `json:"type"`
Function Function `json:"function,omitempty"`
Retrieval Retrieval `json:"retrieval,omitempty"`
WebSearch WebSearch `json:"web_search,omitempty"`
}

// Temperature is float in zhipuai
Temperature float32 `json:"temperature,omitempty"`
// TopP is float in zhipuai
TopP float32 `json:"top_p,omitempty"`
// Contents
Prompt []Prompt `json:"prompt"`
type Function struct {
// Can only contain a-z, A-Z, 0-9, underscores and centered lines. Maximum length limit is 64
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters,omitempty"`
}

// TaskID is used for getting result of AsyncInvoke
TaskID string `json:"task_id,omitempty"`
type Retrieval struct {
KnowledgeID string `json:"knowledge_id"`
PromptTemplate string `json:"prompt_template,omitempty"`
}

// Incremental is only Used for SSE Invoke
Incremental bool `json:"incremental,omitempty"`
type WebSearch struct {
Enable bool `json:"enable,omitempty"`
SearchQuery string `json:"search_query,omitempty"`
}

// +kubebuilder:object:generate=true
// Prompt defines the content of ZhiPuAI Prompt Call
type Prompt struct {
// Message defines the content of ZhiPuAI Prompt Call
type Message struct {
Role Role `json:"role,omitempty"`
Content string `json:"content,omitempty"`
// content and tool_calls must choose one
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// only used when role is `tool`
ToolCallID string `json:"tool_call_id,omitempty"`
}

type ToolCall struct {
ID string `json:"id"`
// web_search, retrieval, function
Type string `json:"type"`
Function FunctionToolCall `json:"function"`
}

type FunctionToolCall struct {
Name string `json:"name"`
Parameters any `json:"parameters"`
}

func DefaultModelParams() ModelParams {
// TODO: should allow user to configure the temperature and top_p of inference
// use 0.8 and 0.7 for now
return ModelParams{
Model: llms.ZhiPuAILite,
Method: ZhiPuAIInvoke,
Temperature: 0.8, // more accurate?
TopP: 0.7,
Prompt: []Prompt{},
Messages: make([]Message, 0),
}
}

Expand All @@ -94,27 +142,24 @@
if a.Model == "" && b.Model != "" {
a.Model = b.Model
}
if a.Method == "" && b.Method != "" {
a.Method = b.Method
}
if a.Temperature == 0 && b.Temperature != 0 {
a.Temperature = b.Temperature
}
if a.TopP == 0 && b.TopP != 0 {
a.TopP = b.TopP
}
if !a.Incremental && b.Incremental {
a.Incremental = b.Incremental
if !a.Stream && b.Stream {
a.Stream = b.Stream
}
if len(a.Prompt) == 0 && len(b.Prompt) > 0 {
a.Prompt = b.Prompt
if len(a.Messages) == 0 && len(b.Messages) > 0 {
a.Messages = b.Messages
}
return a
}

func ValidateModelParams(params ModelParams) error {
if params.Model == "" || params.Method == "" {
return errors.New("model or method is required")
if params.Model == "" {
return errors.New("model is required")
}

if params.Temperature < 0 || params.Temperature > 1 {
Expand Down
27 changes: 19 additions & 8 deletions pkg/llms/zhipuai/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,16 @@ func (response *Response) String() string {
}

type Data struct {
RequestID string `json:"request_id,omitempty"`
TaskID string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Usage Usage `json:"usage,omitempty"`

// for async
RequestID string `json:"request_id,omitempty"`
TaskID string `json:"id,omitempty"`
// The request creation time, a Unix timestamp in seconds.
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices,omitempty"`
Usage Usage `json:"usage,omitempty"`
// for async
TaskStatus string `json:"task_status,omitempty"`
}

type EmbeddingData struct {
Expand All @@ -100,12 +104,19 @@ type EmbeddingText struct {
}

type Usage struct {
TotalTokens int `json:"total_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
}

type Choice struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
// The reason for the termination of the model's reasoning.
// `stop` represents the natural end of reasoning or a trigger stop word.
// `tool_calls` represents the model hit function.
// `length` represents the maximum length of tokens reached.
FinishReason string `json:"finish_reason"`
Message Message `json:"message"`
}

const (
Expand Down
Loading