Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement progressive call results #53

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,43 @@ func TestPublishSubscribe(t *testing.T) {
log.Println(event)
})
}

func TestProgressiveCallResults(t *testing.T) {
session := connect(t)

reg, err := session.Register(
"foo.bar.progress",
func(ctx context.Context, invocation *xconn.Invocation) *xconn.Result {
// Send progress
for i := 1; i <= 3; i++ {
err := invocation.SendProgress([]any{i}, nil)
require.NoError(t, err)
}

// Return final result
return &xconn.Result{Arguments: []any{"done"}}
},
nil,
)
require.NoError(t, err)
require.NotNil(t, reg)

t.Run("ProgressiveCall", func(t *testing.T) {
// Store received progress updates
progressUpdates := make([]int, 0)

result, err := session.CallProgress(context.Background(), "foo.bar.progress", nil, nil, nil,
func(progressiveResult *xconn.Result) {
progress := int(progressiveResult.Arguments[0].(float64))
// Collect received progress
progressUpdates = append(progressUpdates, progress)
})
require.NoError(t, err)

// Verify progressive updates received correctly
require.Equal(t, []int{1, 2, 3}, progressUpdates)

// Verify the final result
require.Equal(t, "done", result.Arguments[0])
})
}
50 changes: 50 additions & 0 deletions examples/rpc_progressive_call_results/callee/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package main

import (
"context"
"log"
"os"
"os/signal"
"time"

"github.com/xconnio/xconn-go"
)

const procedureProgressDownload = "io.xconn.progress.download"

func main() {
// Create and connect a callee client to server
ctx := context.Background()
client := xconn.Client{}
callee, err := client.Connect(ctx, "ws://localhost:8080/ws", "realm1")
if err != nil {
log.Fatalf("Failed to connect to server: %s", err)
}
defer func() { _ = callee.Leave() }()

invocationHandler := func(ctx context.Context, invocation *xconn.Invocation) *xconn.Result {
fileSize := 100 // Simulate a file size of 100 units
for i := 0; i <= fileSize; i += 10 {
progress := i * 100 / fileSize
if err := invocation.SendProgress([]any{progress}, nil); err != nil {
return &xconn.Result{Err: "wamp.error.canceled", Arguments: []any{err.Error()}}
}
time.Sleep(500 * time.Millisecond) // Simulate time taken for download
}

return &xconn.Result{Arguments: []any{"Download complete!"}}
}

registration, err := callee.Register(procedureProgressDownload, invocationHandler, nil)
if err != nil {
log.Fatalf("Failed to register method: %s", err)
}
defer func() { _ = callee.Unregister(registration.ID) }()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
select {
case <-sigChan:
case <-ctx.Done():
}
}
32 changes: 32 additions & 0 deletions examples/rpc_progressive_call_results/caller/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package main

import (
"context"
"fmt"
"log"

"github.com/xconnio/xconn-go"
)

const procedureProgressDownload = "io.xconn.progress.download"

func main() {
// Create and connect a caller client to server
ctx := context.Background()
client := xconn.Client{}
caller, err := client.Connect(ctx, "ws://localhost:8080/ws", "realm1")
if err != nil {
log.Fatalf("Failed to connect to server: %s", err)
}
defer func() { _ = caller.Leave() }()

result, err := caller.CallProgress(ctx, procedureProgressDownload, nil, nil, nil, func(result *xconn.Result) {
progress := result.Arguments[0]
fmt.Printf("Download progress: %v%%\n", progress)
})
if err != nil {
log.Fatalf("Call failed: %s", err)
}

fmt.Println(result.Arguments[0])
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/gobwas/ws v1.4.0
github.com/projectdiscovery/ratelimit v0.0.50
github.com/stretchr/testify v1.9.0
github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0
github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54
github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d
gopkg.in/yaml.v3 v3.0.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0 h1:IU8Sn5EkI0vO/r8C36lEWM3wiSMi7NpxNmpUmt2fUyg=
github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0/go.mod h1:/b7EyR1X9EkOHPQBJGz1KvdjClo1GsalBGIzjQU5+i4=
github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54 h1:uqKiqnmD6XSnX65WbUUNmIyW4L0oaPeOQPytzrxZPyg=
github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54/go.mod h1:/b7EyR1X9EkOHPQBJGz1KvdjClo1GsalBGIzjQU5+i4=
github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9 h1:N0W6uTElFFj/nl88fAtCwUw0y0pdHbtn3QPQri/iGsw=
github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9/go.mod h1:k3t5aYBC+1ujppNAaIgu+Kn7oryRSwsP3o362HkAAho=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
Expand Down
98 changes: 77 additions & 21 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

type InvocationHandler func(ctx context.Context, invocation *Invocation) *Result
type EventHandler func(event *Event)
type ProgressHandler func(result *Result)

type Session struct {
base BaseSession
Expand All @@ -29,6 +30,7 @@ type Session struct {
unregisterRequests sync.Map
registrations sync.Map
callRequests sync.Map
progressHandlers sync.Map

// publish subscribe data structures
subscribeRequests sync.Map
Expand All @@ -52,6 +54,7 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session {
unregisterRequests: sync.Map{},
registrations: sync.Map{},
callRequests: sync.Map{},
progressHandlers: sync.Map{},

subscribeRequests: sync.Map{},
unsubscribeRequests: sync.Map{},
Expand Down Expand Up @@ -120,8 +123,23 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
return fmt.Errorf("received RESULT for unknown request")
}

req := request.(chan *CallResponse)
req <- &CallResponse{msg: result}
progress, _ := result.Details()[wampproto.OptionProgress].(bool)
if progress {
progressHandler, exists := s.progressHandlers.Load(result.RequestID())
if exists {
progHandler := progressHandler.(ProgressHandler)
progHandler(&Result{
Arguments: result.Args(),
KwArguments: result.KwArgs(),
Details: result.Details(),
})
}
} else {
req := request.(chan *CallResponse)
req <- &CallResponse{msg: result}
s.progressHandlers.Delete(result.RequestID())
}

case messages.MessageTypeInvocation:
invocation := msg.(*messages.Invocation)
end, _ := s.registrations.Load(invocation.RegistrationID())
Expand All @@ -133,24 +151,45 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
Details: invocation.Details(),
}

var msgToSend messages.Message
res := endpoint(context.Background(), inv)
if res.Err != "" {
msgToSend = messages.NewError(
int64(invocation.Type()), invocation.RequestID(), map[string]any{}, res.Err, res.Arguments, res.KwArguments,
)
} else {
msgToSend = messages.NewYield(invocation.RequestID(), nil, res.Arguments, res.KwArguments)
receiveProgress, _ := invocation.Details()[wampproto.OptionReceiveProgress].(bool)
if receiveProgress {
inv.SendProgress = func(arguments []any, kwArguments map[string]any) error {
yield := messages.NewYield(invocation.RequestID(), map[string]any{"progress": true}, arguments, kwArguments)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should refer to the const that's created in the dealer.go

payload, err := s.proto.SendMessage(yield)
if err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}

if err = s.base.Write(payload); err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}
return nil
}
}

payload, err := s.proto.SendMessage(msgToSend)
if err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}
go func() {
var msgToSend messages.Message
res := endpoint(context.Background(), inv)
if res.Err != "" {
msgToSend = messages.NewError(
int64(invocation.Type()), invocation.RequestID(), map[string]any{}, res.Err, res.Arguments, res.KwArguments,
)
} else {
msgToSend = messages.NewYield(invocation.RequestID(), nil, res.Arguments, res.KwArguments)
}

payload, err := s.proto.SendMessage(msgToSend)
if err != nil {
log.Println("failed to send yield: %w", err)
return
}

if err = s.base.Write(payload); err != nil {
log.Println("failed to send yield: %w", err)
return
}
}()

if err = s.base.Write(payload); err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}
case messages.MessageTypeSubscribed:
subscribed := msg.(*messages.Subscribed)
request, exists := s.subscribeRequests.Load(subscribed.RequestID())
Expand Down Expand Up @@ -323,10 +362,7 @@ func (s *Session) Unregister(registrationID int64) error {
}
}

func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs map[string]any,
options map[string]any) (*Result, error) {

call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs)
func (s *Session) call(ctx context.Context, call *messages.Call) (*Result, error) {
toSend, err := s.proto.SendMessage(call)
if err != nil {
return nil, err
Expand Down Expand Up @@ -356,6 +392,26 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs
}
}

func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs map[string]any,
options map[string]any) (*Result, error) {

call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs)
return s.call(ctx, call)
}

func (s *Session) CallProgress(ctx context.Context, procedure string, args []any, kwArgs map[string]any,
options map[string]any, progressHandler ProgressHandler) (*Result, error) {

call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs)
if progressHandler == nil {
progressHandler = func(result *Result) {}
}
s.progressHandlers.Store(call.RequestID(), progressHandler)
call.Options()[wampproto.OptionReceiveProgress] = true

return s.call(ctx, call)
}

func (s *Session) Subscribe(topic string, handler EventHandler, options map[string]any) (*Subscription, error) {
subscribe := messages.NewSubscribe(s.idGen.NextID(), options, topic)
toSend, err := s.proto.SendMessage(subscribe)
Expand Down
4 changes: 4 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,15 @@ type Event struct {
Details map[string]any
}

type SendProgress func(arguments []any, kwArguments map[string]any) error

type Invocation struct {
Procedure string
Arguments []any
KwArguments map[string]any
Details map[string]any

SendProgress SendProgress
}

type Result struct {
Expand Down
Loading