diff --git a/go/ai/generate.go b/go/ai/generate.go index 7d7c958fc..2e0f15ae0 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -48,9 +48,12 @@ type ModelMetadata struct { Supports ModelCapabilities } +type ModelMiddleware = core.Middleware[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk] +type ModelFunc = core.Func[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk] + // DefineModel registers the given generate function as an action, and returns a // [ModelAction] that runs it. -func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction { +func DefineModel(provider, name string, metadata *ModelMetadata, middleware []ModelMiddleware, generate ModelFunc) *ModelAction { metadataMap := map[string]any{} if metadata != nil { if metadata.Label != "" { @@ -66,7 +69,7 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c } return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{ "model": metadataMap, - }, generate) + }, middleware, generate) } // LookupModel looks up a [ModelAction] registered by [DefineModel]. diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index 01d1ab13a..9e1bf452e 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -15,6 +15,8 @@ package ai import ( + "context" + "fmt" "strings" "testing" ) @@ -189,6 +191,50 @@ func TestValidCandidate(t *testing.T) { }) } +var echoModel = DefineModel("echo", "echo", nil, nil, func(ctx context.Context, req *GenerateRequest, cb func(context.Context, *GenerateResponseChunk) error) (*GenerateResponse, error) { + input := req.Messages[0].Content[0].Text + output := fmt.Sprintf("Echo: %q", input) + + r := &GenerateResponse{ + Candidates: []*Candidate{ + { + Message: &Message{ + Content: []*Part{ + NewTextPart(output), + }, + }, + }, + }, + Request: req, + } + return r, nil +}) + +func TestGenerate(t *testing.T) { + response, err := Generate(context.Background(), echoModel, &GenerateRequest{ + Messages: []*Message{ + { + Role: RoleUser, + Content: []*Part{ + NewTextPart("banana"), + }, + }, + }, + }, nil) + if err != nil { + t.Fatal(err) + } + + got, err := response.Text() + if err != nil { + t.Fatal(err) + } + want := "Echo: \"banana\"" + if got != want { + t.Errorf("Text() == %q, want %q", got, want) + } +} + func JSONMarkdown(text string) string { return "```json\n" + text + "\n```" } diff --git a/go/core/action.go b/go/core/action.go index 5a30c9467..08e16d67e 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -28,6 +28,36 @@ import ( "github.com/invopop/jsonschema" ) +// Middleware is a function that takes in an action handler function and +// returns a new handler function that might be changing input/output in +// some way. +// +// Middleware functions can: +// - execute arbitrary code; +// - change the request and response; +// - terminate response by returning a response (or error); +// - call the next middleware function. +type Middleware[I, O, S any] func(Func[I, O, S]) Func[I, O, S] + +// Middlewares returns an array of middlewares that are passes in as an argument. +// core.Middlewares(apple, banana) is identical to []core.Middleware[InputType, OutputType]{apple, banana} +func Middlewares[I, O, S any](ms ...Middleware[I, O, S]) []Middleware[I, O, S] { + return ms +} + +// ChainMiddleware creates a new Middleware that applies a sequence of +// Middlewares, so that they execute in the given order when handling action +// request. +// In other words, ChainMiddleware(m1, m2)(handler) = m1(m2(handler)) +func ChainMiddleware[I, O, S any](middlewares ...Middleware[I, O, S]) Middleware[I, O, S] { + return func(h Func[I, O, S]) Func[I, O, S] { + for i := range middlewares { + h = middlewares[len(middlewares)-1-i](h) + } + return h + } +} + // Func is the type of function that Actions and Flows execute. // It takes an input of type Int and returns an output of type Out, optionally // streaming values of type Stream incrementally by invoking a callback. @@ -63,6 +93,7 @@ type Action[In, Out, Stream any] struct { // optional description string metadata map[string]any + middleware []Middleware[In, Out, Stream] } // See js/core/src/action.ts @@ -78,18 +109,18 @@ func defineAction[In, Out any](r *registry, provider, name string, atype atype.A return a } -func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn) +func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, middleware []Middleware[In, Out, Stream], fn Func[In, Out, Stream]) *Action[In, Out, Stream] { + return defineStreamingAction(globalRegistry, provider, name, atype, metadata, middleware, fn) } -func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - a := newStreamingAction(name, atype, metadata, fn) +func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, middleware []Middleware[In, Out, Stream], fn Func[In, Out, Stream]) *Action[In, Out, Stream] { + a := newStreamingAction(name, atype, metadata, middleware, fn) r.registerAction(provider, a) return a } func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - return DefineStreamingAction(provider, name, atype.Custom, metadata, fn) + return DefineStreamingAction(provider, name, atype.Custom, metadata, nil, fn) } // DefineActionWithInputSchema creates a new Action and registers it. @@ -108,13 +139,13 @@ func defineActionWithInputSchema[Out any](r *registry, provider, name string, at // newAction creates a new Action with the given name and non-streaming function. func newAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return newStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) { + return newStreamingAction(name, atype, metadata, nil, func(ctx context.Context, in In, cb NoStream) (Out, error) { return fn(ctx, in) }) } // newStreamingAction creates a new Action with the given name and streaming function. -func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, middleware []Middleware[In, Out, Stream], fn Func[In, Out, Stream]) *Action[In, Out, Stream] { var i In var o Out return &Action[In, Out, Stream]{ @@ -127,6 +158,7 @@ func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType inputSchema: inferJSONSchema(i), outputSchema: inferJSONSchema(o), metadata: metadata, + middleware: middleware, } } @@ -169,6 +201,7 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con // This action has probably not been registered. tstate = globalRegistry.tstate } + return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input, func(ctx context.Context, input In) (Out, error) { start := time.Now() @@ -178,7 +211,8 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con } var output Out if err == nil { - output, err = a.fn(ctx, input, cb) + dispatch := ChainMiddleware(a.middleware...) + output, err = dispatch(a.fn)(ctx, input, cb) if err == nil { if err = validateValue(output, a.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) diff --git a/go/core/action_test.go b/go/core/action_test.go index a73b0fd49..c900ae657 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -27,6 +27,22 @@ func inc(_ context.Context, x int) (int, error) { return x + 1, nil } +func wrapRequest(next Func[string, string, struct{}]) Func[string, string, struct{}] { + return func(ctx context.Context, request string, _ NoStream) (string, error) { + return next(ctx, "("+request+")", nil) + } +} + +func wrapResponse(next Func[string, string, struct{}]) Func[string, string, struct{}] { + return func(ctx context.Context, request string, _ NoStream) (string, error) { + nextResponse, err := next(ctx, request, nil) + if err != nil { + return "", err + } + return "[" + nextResponse + "]", nil + } +} + func TestActionRun(t *testing.T) { a := newAction("inc", atype.Custom, nil, inc) got, err := a.Run(context.Background(), 3, nil) @@ -70,7 +86,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := newStreamingAction("count", atype.Custom, nil, count) + a := newStreamingAction("count", atype.Custom, nil, nil, count) const n = 3 // Non-streaming. @@ -126,3 +142,43 @@ func TestActionTracing(t *testing.T) { } t.Fatalf("did not find trace named %q", actionName) } + +func TestActionMiddleware(t *testing.T) { + ctx := context.Background() + + sayHello := newStreamingAction("hello", atype.Custom, nil, Middlewares(wrapRequest, wrapResponse), func(ctx context.Context, input string, _ NoStream) (string, error) { + return "Hello " + input, nil + }) + + got, err := sayHello.Run(ctx, "Pavel", nil) + if err != nil { + t.Fatal(err) + } + want := "[Hello (Pavel)]" + if got != want { + t.Fatalf("got %v, want %v", got, want) + } +} + +func TestActionInterruptedMiddleware(t *testing.T) { + ctx := context.Background() + + interrupt := func(next Func[string, string, struct{}]) Func[string, string, struct{}] { + return func(ctx context.Context, request string, _ NoStream) (string, error) { + return "interrupt (request: \"" + request + "\")", nil + } + } + + a := newStreamingAction("hello", atype.Custom, nil, Middlewares(wrapRequest, interrupt, wrapResponse), func(ctx context.Context, input string, _ NoStream) (string, error) { + return "Hello " + input, nil + }) + + got, err := a.Run(ctx, "Pavel", nil) + if err != nil { + t.Fatal(err) + } + want := "interrupt (request: \"(Pavel)\")" + if got != want { + t.Fatalf("got %v, want %v", got, want) + } +} diff --git a/go/core/flow.go b/go/core/flow.go index 63b0234cf..cfae1403e 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -260,7 +260,7 @@ func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowStat tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") return f.runInstruction(ctx, inst, streamingCallback[Stream](cb)) } - return newStreamingAction(f.name, atype.Flow, metadata, cback) + return newStreamingAction(f.name, atype.Flow, metadata, nil, cback) } // runInstruction performs one of several actions on a flow, as determined by msg. diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 0890d8d50..c5ac6f8b5 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -42,7 +42,7 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context. } func TestExecute(t *testing.T) { - testModel := ai.DefineModel("test", "test", nil, testGenerate) + testModel := ai.DefineModel("test", "test", nil, nil, testGenerate) p, err := New("TestExecute", "TestExecute", Config{ModelAction: testModel}) if err != nil { t.Fatal(err) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 9b8128b82..9b59ff713 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -97,7 +97,7 @@ func defineModel(name string, client *genai.Client) { }, } g := generator{model: name, client: client} - ai.DefineModel(provider, name, meta, g.generate) + ai.DefineModel(provider, name, meta, nil, g.generate) } func defineEmbedder(name string, client *genai.Client) { diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 084eb7c6a..1e2e7cf52 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -81,7 +81,7 @@ func initModels(ctx context.Context, cfg Config) error { }, } g := &generator{model: name, client: gclient} - ai.DefineModel(provider, name, meta, g.generate) + ai.DefineModel(provider, name, meta, nil, g.generate) } return nil } diff --git a/js/flow/src/flow.ts b/js/flow/src/flow.ts index e251f99b0..0e228ea85 100644 --- a/js/flow/src/flow.ts +++ b/js/flow/src/flow.ts @@ -850,7 +850,7 @@ export function startFlowsServer(params?: { flows.forEach((f) => { const flowPath = `/${pathPrefix}${f.name}`; logger.info(` - ${flowPath}`); - // Add middlware + // Add middleware f.middleware?.forEach((m) => { app.post(flowPath, m); });