From c04fd93787f7373d51ad2d24b7c7c917cf48b26c Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 20 Sep 2024 19:03:46 +0500 Subject: [PATCH 1/4] Update wampproto to latest --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ebffc54..2be30af 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/gobwas/ws v1.4.0 github.com/projectdiscovery/ratelimit v0.0.50 github.com/stretchr/testify v1.9.0 - github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0 + github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54 github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9 golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 445a9a5..25900d8 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0 h1:IU8Sn5EkI0vO/r8C36lEWM3wiSMi7NpxNmpUmt2fUyg= -github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0/go.mod h1:/b7EyR1X9EkOHPQBJGz1KvdjClo1GsalBGIzjQU5+i4= +github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54 h1:uqKiqnmD6XSnX65WbUUNmIyW4L0oaPeOQPytzrxZPyg= +github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54/go.mod h1:/b7EyR1X9EkOHPQBJGz1KvdjClo1GsalBGIzjQU5+i4= github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9 h1:N0W6uTElFFj/nl88fAtCwUw0y0pdHbtn3QPQri/iGsw= github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9/go.mod h1:k3t5aYBC+1ujppNAaIgu+Kn7oryRSwsP3o362HkAAho= github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc= From 1000053d25bd711ad0498f43f50f995733815f22 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 20 Sep 2024 19:04:33 +0500 Subject: [PATCH 2/4] Implement progressive call results --- session.go | 98 ++++++++++++++++++++++++++++++++++++++++++------------ types.go | 4 +++ 2 files changed, 81 insertions(+), 21 deletions(-) diff --git a/session.go b/session.go index 0aea680..1cdc688 100644 --- a/session.go +++ b/session.go @@ -16,6 +16,7 @@ import ( type InvocationHandler func(ctx context.Context, invocation *Invocation) *Result type EventHandler func(event *Event) +type ProgressHandler func(result *Result) type Session struct { base BaseSession @@ -29,6 +30,7 @@ type Session struct { unregisterRequests sync.Map registrations sync.Map callRequests sync.Map + progressHandlers sync.Map // publish subscribe data structures subscribeRequests sync.Map @@ -52,6 +54,7 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session { unregisterRequests: sync.Map{}, registrations: sync.Map{}, callRequests: sync.Map{}, + progressHandlers: sync.Map{}, subscribeRequests: sync.Map{}, unsubscribeRequests: sync.Map{}, @@ -120,8 +123,23 @@ func (s *Session) processIncomingMessage(msg messages.Message) error { return fmt.Errorf("received RESULT for unknown request") } - req := request.(chan *CallResponse) - req <- &CallResponse{msg: result} + progress, _ := result.Details()[wampproto.OptionProgress].(bool) + if progress { + progressHandler, exists := s.progressHandlers.Load(result.RequestID()) + if exists { + progHandler := progressHandler.(ProgressHandler) + progHandler(&Result{ + Arguments: result.Args(), + KwArguments: result.KwArgs(), + Details: result.Details(), + }) + } + } else { + req := request.(chan *CallResponse) + req <- &CallResponse{msg: result} + s.progressHandlers.Delete(result.RequestID()) + } + case messages.MessageTypeInvocation: invocation := msg.(*messages.Invocation) end, _ := s.registrations.Load(invocation.RegistrationID()) @@ -133,24 +151,45 @@ func (s *Session) processIncomingMessage(msg messages.Message) error { Details: invocation.Details(), } - var msgToSend messages.Message - res := endpoint(context.Background(), inv) - if res.Err != "" { - msgToSend = messages.NewError( - int64(invocation.Type()), invocation.RequestID(), map[string]any{}, res.Err, res.Arguments, res.KwArguments, - ) - } else { - msgToSend = messages.NewYield(invocation.RequestID(), nil, res.Arguments, res.KwArguments) + receiveProgress, _ := invocation.Details()[wampproto.OptionReceiveProgress].(bool) + if receiveProgress { + inv.SendProgress = func(arguments []any, kwArguments map[string]any) error { + yield := messages.NewYield(invocation.RequestID(), map[string]any{"progress": true}, arguments, kwArguments) + payload, err := s.proto.SendMessage(yield) + if err != nil { + return fmt.Errorf("failed to send yield: %w", err) + } + + if err = s.base.Write(payload); err != nil { + return fmt.Errorf("failed to send yield: %w", err) + } + return nil + } } - payload, err := s.proto.SendMessage(msgToSend) - if err != nil { - return fmt.Errorf("failed to send yield: %w", err) - } + go func() { + var msgToSend messages.Message + res := endpoint(context.Background(), inv) + if res.Err != "" { + msgToSend = messages.NewError( + int64(invocation.Type()), invocation.RequestID(), map[string]any{}, res.Err, res.Arguments, res.KwArguments, + ) + } else { + msgToSend = messages.NewYield(invocation.RequestID(), nil, res.Arguments, res.KwArguments) + } + + payload, err := s.proto.SendMessage(msgToSend) + if err != nil { + log.Println("failed to send yield: %w", err) + return + } + + if err = s.base.Write(payload); err != nil { + log.Println("failed to send yield: %w", err) + return + } + }() - if err = s.base.Write(payload); err != nil { - return fmt.Errorf("failed to send yield: %w", err) - } case messages.MessageTypeSubscribed: subscribed := msg.(*messages.Subscribed) request, exists := s.subscribeRequests.Load(subscribed.RequestID()) @@ -323,10 +362,7 @@ func (s *Session) Unregister(registrationID int64) error { } } -func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs map[string]any, - options map[string]any) (*Result, error) { - - call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs) +func (s *Session) call(ctx context.Context, call *messages.Call) (*Result, error) { toSend, err := s.proto.SendMessage(call) if err != nil { return nil, err @@ -356,6 +392,26 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs } } +func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs map[string]any, + options map[string]any) (*Result, error) { + + call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs) + return s.call(ctx, call) +} + +func (s *Session) CallProgress(ctx context.Context, procedure string, args []any, kwArgs map[string]any, + options map[string]any, progressHandler ProgressHandler) (*Result, error) { + + call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs) + if progressHandler == nil { + progressHandler = func(result *Result) {} + } + s.progressHandlers.Store(call.RequestID(), progressHandler) + call.Options()[wampproto.OptionReceiveProgress] = true + + return s.call(ctx, call) +} + func (s *Session) Subscribe(topic string, handler EventHandler, options map[string]any) (*Subscription, error) { subscribe := messages.NewSubscribe(s.idGen.NextID(), options, topic) toSend, err := s.proto.SendMessage(subscribe) diff --git a/types.go b/types.go index 9eef0a4..134a102 100644 --- a/types.go +++ b/types.go @@ -115,11 +115,15 @@ type Event struct { Details map[string]any } +type SendProgress func(arguments []any, kwArguments map[string]any) error + type Invocation struct { Procedure string Arguments []any KwArguments map[string]any Details map[string]any + + SendProgress SendProgress } type Result struct { From 06370316ebf9e40a89ea0734713fe649c6290f09 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 20 Sep 2024 19:05:03 +0500 Subject: [PATCH 3/4] Add test for progressive call results --- client_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/client_test.go b/client_test.go index ffad0c3..5a778f4 100644 --- a/client_test.go +++ b/client_test.go @@ -95,3 +95,43 @@ func TestPublishSubscribe(t *testing.T) { log.Println(event) }) } + +func TestProgressiveCallResults(t *testing.T) { + session := connect(t) + + reg, err := session.Register( + "foo.bar.progress", + func(ctx context.Context, invocation *xconn.Invocation) *xconn.Result { + // Send progress + for i := 1; i <= 3; i++ { + err := invocation.SendProgress([]any{i}, nil) + require.NoError(t, err) + } + + // Return final result + return &xconn.Result{Arguments: []any{"done"}} + }, + nil, + ) + require.NoError(t, err) + require.NotNil(t, reg) + + t.Run("ProgressiveCall", func(t *testing.T) { + // Store received progress updates + progressUpdates := make([]int, 0) + + result, err := session.CallProgress(context.Background(), "foo.bar.progress", nil, nil, nil, + func(progressiveResult *xconn.Result) { + progress := int(progressiveResult.Arguments[0].(float64)) + // Collect received progress + progressUpdates = append(progressUpdates, progress) + }) + require.NoError(t, err) + + // Verify progressive updates received correctly + require.Equal(t, []int{1, 2, 3}, progressUpdates) + + // Verify the final result + require.Equal(t, "done", result.Arguments[0]) + }) +} From 227026de655f275aabfdfbfc961278645e5e65eb Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 20 Sep 2024 19:05:56 +0500 Subject: [PATCH 4/4] Add example for progressive call results --- .../callee/main.go | 50 +++++++++++++++++++ .../caller/main.go | 32 ++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 examples/rpc_progressive_call_results/callee/main.go create mode 100644 examples/rpc_progressive_call_results/caller/main.go diff --git a/examples/rpc_progressive_call_results/callee/main.go b/examples/rpc_progressive_call_results/callee/main.go new file mode 100644 index 0000000..4122b5e --- /dev/null +++ b/examples/rpc_progressive_call_results/callee/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "log" + "os" + "os/signal" + "time" + + "github.com/xconnio/xconn-go" +) + +const procedureProgressDownload = "io.xconn.progress.download" + +func main() { + // Create and connect a callee client to server + ctx := context.Background() + client := xconn.Client{} + callee, err := client.Connect(ctx, "ws://localhost:8080/ws", "realm1") + if err != nil { + log.Fatalf("Failed to connect to server: %s", err) + } + defer func() { _ = callee.Leave() }() + + invocationHandler := func(ctx context.Context, invocation *xconn.Invocation) *xconn.Result { + fileSize := 100 // Simulate a file size of 100 units + for i := 0; i <= fileSize; i += 10 { + progress := i * 100 / fileSize + if err := invocation.SendProgress([]any{progress}, nil); err != nil { + return &xconn.Result{Err: "wamp.error.canceled", Arguments: []any{err.Error()}} + } + time.Sleep(500 * time.Millisecond) // Simulate time taken for download + } + + return &xconn.Result{Arguments: []any{"Download complete!"}} + } + + registration, err := callee.Register(procedureProgressDownload, invocationHandler, nil) + if err != nil { + log.Fatalf("Failed to register method: %s", err) + } + defer func() { _ = callee.Unregister(registration.ID) }() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + select { + case <-sigChan: + case <-ctx.Done(): + } +} diff --git a/examples/rpc_progressive_call_results/caller/main.go b/examples/rpc_progressive_call_results/caller/main.go new file mode 100644 index 0000000..a4d49f8 --- /dev/null +++ b/examples/rpc_progressive_call_results/caller/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/xconnio/xconn-go" +) + +const procedureProgressDownload = "io.xconn.progress.download" + +func main() { + // Create and connect a caller client to server + ctx := context.Background() + client := xconn.Client{} + caller, err := client.Connect(ctx, "ws://localhost:8080/ws", "realm1") + if err != nil { + log.Fatalf("Failed to connect to server: %s", err) + } + defer func() { _ = caller.Leave() }() + + result, err := caller.CallProgress(ctx, procedureProgressDownload, nil, nil, nil, func(result *xconn.Result) { + progress := result.Arguments[0] + fmt.Printf("Download progress: %v%%\n", progress) + }) + if err != nil { + log.Fatalf("Call failed: %s", err) + } + + fmt.Println(result.Arguments[0]) +}