Skip to content

Commit

Permalink
Add tests and fix cyclic dependency issues
Browse files Browse the repository at this point in the history
  • Loading branch information
shuheiktgw committed Sep 26, 2023
1 parent ea0d5ac commit 6b7f5ec
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 13 deletions.
2 changes: 1 addition & 1 deletion resolver/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (c *context) argIndex() int {
return c.argIdx
}

func (c *context) addError(e *LocationError) {
func (c *context) addError(e error) {
c.errorBuilder.add(e)
}

Expand Down
12 changes: 5 additions & 7 deletions resolver/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ func (n *MessageDependencyGraphNode) ExpectedMessageArguments() []*Argument {
}

// CreateMessageDependencyGraph creates a dependency graph for all messages with message options defined.
func CreateMessageDependencyGraph(ctx *context, msgs []*Message) *MessageDependencyGraph {
// If a circular reference occurs, return an error.
func CreateMessageDependencyGraph(ctx *context, msgs []*Message) (*MessageDependencyGraph, error) {
msgToNode := map[*Message]*MessageDependencyGraphNode{}
for _, msg := range msgs {
if msg.Rule == nil {
Expand Down Expand Up @@ -92,7 +93,7 @@ func CreateMessageDependencyGraph(ctx *context, msgs []*Message) *MessageDepende
}
}
if len(roots) == 0 {
return nil
return nil, nil
}
sort.Slice(roots, func(i, j int) bool {
return roots[i].Message.Name < roots[j].Message.Name
Expand All @@ -102,11 +103,9 @@ func CreateMessageDependencyGraph(ctx *context, msgs []*Message) *MessageDepende
RootArgs: map[*MessageDependencyGraphNode]*Message{},
}
if err := validateMessageGraph(graph); err != nil {
fileName := roots[0].Message.File.Name
ctx.addError(ErrWithLocation(err.Error(), source.FileLocation(fileName)))
return nil
return nil, err
}
return graph
return graph, nil
}

type MessageRuleDependencyGraph struct {
Expand Down Expand Up @@ -214,7 +213,6 @@ type MessageRuleDependencyGraphNode struct {
// CreateMessageRuleDependencyGraph construct a dependency graph using the name-based reference dependencies used in the method calls
// and the arguments used to retrieve the dependency messages.
// Requires reference resolution for arguments that use prior name-based references.
// If a circular reference occurs, return an error.
func CreateMessageRuleDependencyGraph(ctx *context, baseMsg *Message, rule *MessageRule) *MessageRuleDependencyGraph {
msgToNode := map[*Message]*MessageRuleDependencyGraphNode{}
if rule.MethodCall != nil && rule.MethodCall.Response != nil {
Expand Down
19 changes: 14 additions & 5 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ func (r *Resolver) Resolve() (*Result, error) {
// In order to return multiple errors with source code location information,
// we add all errors to the context when they occur.
// Therefore, functions called from Resolve() do not return errors directly.
// Instead, it must return all errors captured by context in ctx.error().
// Instead, it must return all errors captured by context in ctx.error()
// except for those that fail subsequent processing, such as cyclic dependencies
ctx := newContext()

r.resolvePackageAndFileReference(ctx, append(r.files, stdFileDescriptors()...))
Expand All @@ -93,7 +94,11 @@ func (r *Resolver) Resolve() (*Result, error) {
return &Result{Warnings: ctx.warnings()}, ctx.error()
}

r.resolveMessageArgumentReference(ctx, files)
err := r.resolveMessageArgumentReference(ctx, files)
if err != nil {
ctx.addError(err)
return &Result{Warnings: ctx.warnings()}, ctx.error()
}

services := r.servicesWithRule(ctx, files)
return &Result{Services: services, Warnings: ctx.warnings()}, ctx.error()
Expand Down Expand Up @@ -2034,13 +2039,16 @@ func (r *Resolver) setValueMessageArgumentReferenceForMessageFieldValue(fields m
}

// resolveresolveMessageArgumentReference constructs message arguments using a dependency graph and assigns them to each message.
func (r *Resolver) resolveMessageArgumentReference(ctx *context, files []*File) {
func (r *Resolver) resolveMessageArgumentReference(ctx *context, files []*File) error {
r.resolveMethodResponseMessageArgument(ctx, files)

// create a dependency graph for all messages.
graph := CreateMessageDependencyGraph(ctx, r.allMessages(files))
graph, err := CreateMessageDependencyGraph(ctx, r.allMessages(files))
if err != nil {
return err
}
if graph == nil {
return
return nil
}
for _, root := range graph.Roots {
reqMsg := r.lookupRequestMessageFromResponseMessage(root.Message)
Expand Down Expand Up @@ -2072,6 +2080,7 @@ func (r *Resolver) resolveMessageArgumentReference(ctx *context, files []*File)
r.resolveValueMessageArgumentReference(ctx, msg)
}
}
return nil
}

func (r *Resolver) resolveMethodResponseMessageArgument(ctx *context, files []*File) {
Expand Down
35 changes: 35 additions & 0 deletions resolver/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package resolver_test

import (
"path/filepath"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -2140,6 +2141,40 @@ func TestLiteral(t *testing.T) {
}
}

func TestValidation(t *testing.T) {
tests := []struct {
desc string
file string
expected string
}{
{
desc: "message nodes have cyclic dependency",
file: "message_cyclic_dependency.proto",
expected: `found cyclic dependency in "org.federation.A" message`,
},
{
desc: "message rule nodes have cyclic dependency",
file: "message_rule_cyclic_dependency.proto",
expected: `found cyclic dependency in "org.federation.B" message`,
},
}

for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
r := resolver.New(testutil.Compile(t, filepath.Join(testutil.RepoRoot(), "validator", "testdata", tc.file)))
_, err := r.Resolve()

if err == nil {
t.Fatal("expected to receive an error but got nil")
}

if !strings.Contains(err.Error(), tc.expected) {
t.Fatalf("error %q should have contained substrings %q", err.Error(), tc.expected)
}
})
}
}

func getUserProtoBuilder(t *testing.T) *testutil.FileBuilder {
ub := testutil.NewFileBuilder("user.proto")
ref := testutil.NewBuilderReferenceManager(ub)
Expand Down
62 changes: 62 additions & 0 deletions validator/testdata/message_cyclic_dependency.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
syntax = "proto3";

package org.federation;

import "grpc/federation/federation.proto";

option go_package = "example/federation;federation";

service FederationService {
option (grpc.federation.service) = {};
rpc Get(GetRequest) returns (GetResponse) {};
}

message GetRequest {}

message GetResponse {
option (grpc.federation.message) = {
messages: [
{ name: "a", message: "A" },
{ name: "b", message: "B" }
]
};
string aaaname = 1 [(grpc.federation.field).by = "a.aaaname"];
string bname = 2 [(grpc.federation.field).by = "b.name"];
}

message A {
option (grpc.federation.message) = {
messages: [
{ name: "aa", message: "AA" },
{ name: "ab", message: "AB" }
]
};
string aaaname = 1 [(grpc.federation.field).by = "aa.aaaname"];
string abname = 2 [(grpc.federation.field).by = "ab.name"];
}

message AA {
option (grpc.federation.message) = {
messages: [
{ name: "aaa", message: "AAA" }
]
};
string aaaname = 1 [(grpc.federation.field).by = "aaa.name"];
}

message AB {
string name = 1 [(grpc.federation.field).string = "ab"];
}

message AAA {
option (grpc.federation.message) = {
messages: [
{ name: "a", message: "A" }
]
};
string name = 1 [(grpc.federation.field).by = "a.name"];
}

message B {
string name = 1 [(grpc.federation.field).string = "b"];
}
48 changes: 48 additions & 0 deletions validator/testdata/message_rule_cyclic_dependency.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
syntax = "proto3";

package org.federation;

import "grpc/federation/federation.proto";

option go_package = "example/federation;federation";

service FederationService {
option (grpc.federation.service) = {};
rpc Get(FederatedRequest) returns (FederatedResponse) {};
}

message FederatedRequest {
string id = 1;
}

message FederatedResponse {
option (grpc.federation.message) = {
messages: [
{name: "a", message: "A", args: [{name: "id", by: "$.id"}]},
{name: "b", message: "B", args: [{name: "aid", by: "a.id"}, {name: "did", by: "d.id"}]},
{name: "c", message: "C", args: [{name: "id", by: "b.id"}]},
{name: "d", message: "D", args: [{name: "id", by: "c.id"}]}
]
};

string aid = 1 [(grpc.federation.field).by = "a.id"];
string bid = 2 [(grpc.federation.field).by = "b.id"];
string cid = 3 [(grpc.federation.field).by = "c.id"];
string did = 4 [(grpc.federation.field).by = "d.id"];
}

message A {
string id = 1 [(grpc.federation.field).string = "ID"];
}

message B {
string id = 1 [(grpc.federation.field).string = "ID"];
}

message C {
string id = 1 [(grpc.federation.field).string = "ID"];
}

message D {
string id = 1 [(grpc.federation.field).string = "ID"];
}

0 comments on commit 6b7f5ec

Please sign in to comment.