From 280e547f8ae244efaae38b9b64651eaacc20f4d6 Mon Sep 17 00:00:00 2001 From: Vladislav Fursov Date: Wed, 1 Nov 2023 23:04:39 +0400 Subject: [PATCH] fix optional proto fields false positive --- processor.go | 94 +++++++++++++++++++++++++++++- protogetter.go | 1 + testdata/proto/test.pb.go | 117 +++++++++++++++++++++++++++++--------- testdata/proto/test.proto | 7 +++ testdata/test.go | 11 ++++ testdata/test.go.golden | 11 ++++ 6 files changed, 212 insertions(+), 29 deletions(-) diff --git a/processor.go b/processor.go index 167248d..e42b38d 100644 --- a/processor.go +++ b/processor.go @@ -87,6 +87,50 @@ func (c *processor) process(n ast.Node) (*Result, error) { c.writeFrom("*") c.processInner(x.X) + case *ast.BinaryExpr: + // Check if the expression is a comparison. + if x.Op != token.EQL && x.Op != token.NEQ { + return &Result{}, nil + } + + // Check if one of the operands is nil. + + xIdent, xOk := x.X.(*ast.Ident) + yIdent, yOk := x.Y.(*ast.Ident) + + xIsNil := xOk && xIdent.Name == "nil" + yIsNil := yOk && yIdent.Name == "nil" + + if !xIsNil && !yIsNil { + return &Result{}, nil + } + + // Extract the non-nil operand for further checks + + var expr ast.Expr + if xIsNil { + expr = x.Y + } else { + expr = x.X + } + + se, ok := expr.(*ast.SelectorExpr) + if !ok { + return &Result{}, nil + } + + if !isProtoMessage(c.info, se.X) { + return &Result{}, nil + } + + // Check if the Getter function of the protobuf message returns a pointer. + hasPointer, ok := getterResultHasPointer(c.info, se.X, se.Sel.Name) + if !ok || hasPointer { + return &Result{}, nil + } + + c.filter.AddPos(x.X.Pos()) + default: return nil, fmt.Errorf("not implemented for type: %s (%s)", reflect.TypeOf(x), formatNode(n)) } @@ -225,14 +269,14 @@ func isProtoMessage(info *types.Info, expr ast.Expr) bool { return false } -func methodIsExists(info *types.Info, x ast.Expr, name string) bool { +func typesNamed(info *types.Info, x ast.Expr) (*types.Named, bool) { if info == nil { - return false + return nil, false } t := info.TypeOf(x) if t == nil { - return false + return nil, false } ptr, ok := t.Underlying().(*types.Pointer) @@ -241,6 +285,15 @@ func methodIsExists(info *types.Info, x ast.Expr, name string) bool { } named, ok := t.(*types.Named) + if !ok { + return nil, false + } + + return named, true +} + +func methodIsExists(info *types.Info, x ast.Expr, name string) bool { + named, ok := typesNamed(info, x) if !ok { return false } @@ -253,3 +306,38 @@ func methodIsExists(info *types.Info, x ast.Expr, name string) bool { return false } + +func getterResultHasPointer(info *types.Info, x ast.Expr, name string) (hasPointer, ok bool) { + named, ok := typesNamed(info, x) + if !ok { + return false, false + } + + for i := 0; i < named.NumMethods(); i++ { + method := named.Method(i) + if method.Name() != "Get"+name { + continue + } + + var sig *types.Signature + sig, ok = method.Type().(*types.Signature) + if !ok { + return false, false + } + + results := sig.Results() + if results.Len() == 0 { + return false, false + } + + firstType := results.At(0) + _, ok = firstType.Type().(*types.Pointer) + if !ok { + return false, true + } + + return true, true + } + + return false, false +} diff --git a/protogetter.go b/protogetter.go index 0061994..bb84c6b 100644 --- a/protogetter.go +++ b/protogetter.go @@ -97,6 +97,7 @@ func Run(pass *analysis.Pass, cfg *Config) ([]Issue, error) { nodeTypes := []ast.Node{ (*ast.AssignStmt)(nil), + (*ast.BinaryExpr)(nil), (*ast.CallExpr)(nil), (*ast.SelectorExpr)(nil), (*ast.StarExpr)(nil), diff --git a/testdata/proto/test.pb.go b/testdata/proto/test.pb.go index a34e426..31e3b8d 100644 --- a/testdata/proto/test.pb.go +++ b/testdata/proto/test.pb.go @@ -20,6 +20,53 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +// Optional enum +type Test_OEnum int32 + +const ( + Test_O_ENUM1 Test_OEnum = 0 + Test_O_ENUM2 Test_OEnum = 1 +) + +// Enum value maps for Test_OEnum. +var ( + Test_OEnum_name = map[int32]string{ + 0: "O_ENUM1", + 1: "O_ENUM2", + } + Test_OEnum_value = map[string]int32{ + "O_ENUM1": 0, + "O_ENUM2": 1, + } +) + +func (x Test_OEnum) Enum() *Test_OEnum { + p := new(Test_OEnum) + *p = x + return p +} + +func (x Test_OEnum) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Test_OEnum) Descriptor() protoreflect.EnumDescriptor { + return file_test_proto_enumTypes[0].Descriptor() +} + +func (Test_OEnum) Type() protoreflect.EnumType { + return &file_test_proto_enumTypes[0] +} + +func (x Test_OEnum) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Test_OEnum.Descriptor instead. +func (Test_OEnum) EnumDescriptor() ([]byte, []int) { + return file_test_proto_rawDescGZIP(), []int{0, 0} +} + type Test struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -37,7 +84,8 @@ type Test struct { Embedded *Embedded `protobuf:"bytes,10,opt,name=embedded,proto3" json:"embedded,omitempty"` RepeatedEmbeddeds []*Embedded `protobuf:"bytes,11,rep,name=repeated_embeddeds,json=repeatedEmbeddeds,proto3" json:"repeated_embeddeds,omitempty"` // issue #4 - OptBool *bool `protobuf:"varint,12,opt,name=opt_bool,json=optBool,proto3,oneof" json:"opt_bool,omitempty"` + OptBool *bool `protobuf:"varint,12,opt,name=opt_bool,json=optBool,proto3,oneof" json:"opt_bool,omitempty"` + OptEnum *Test_OEnum `protobuf:"varint,13,opt,name=opt_enum,json=optEnum,proto3,enum=Test_OEnum,oneof" json:"opt_enum,omitempty"` } func (x *Test) Reset() { @@ -156,6 +204,13 @@ func (x *Test) GetOptBool() bool { return false } +func (x *Test) GetOptEnum() Test_OEnum { + if x != nil && x.OptEnum != nil { + return *x.OptEnum + } + return Test_O_ENUM1 +} + type Embedded struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -214,7 +269,7 @@ func (x *Embedded) GetEmbedded() *Embedded { var File_test_proto protoreflect.FileDescriptor var file_test_proto_rawDesc = []byte{ - 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa2, 0x02, 0x0a, + 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xff, 0x02, 0x0a, 0x04, 0x54, 0x65, 0x73, 0x74, 0x12, 0x0c, 0x0a, 0x01, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, 0x52, 0x01, 0x64, 0x12, 0x0c, 0x0a, 0x01, 0x66, 0x18, 0x02, 0x20, 0x01, 0x28, 0x02, 0x52, 0x01, 0x66, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x33, 0x32, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, @@ -232,17 +287,23 @@ var file_test_proto_rawDesc = []byte{ 0x64, 0x52, 0x11, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x73, 0x12, 0x1e, 0x0a, 0x08, 0x6f, 0x70, 0x74, 0x5f, 0x62, 0x6f, 0x6f, 0x6c, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x42, 0x6f, 0x6f, - 0x6c, 0x88, 0x01, 0x01, 0x42, 0x0b, 0x0a, 0x09, 0x5f, 0x6f, 0x70, 0x74, 0x5f, 0x62, 0x6f, 0x6f, - 0x6c, 0x22, 0x3f, 0x0a, 0x08, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x12, 0x0c, 0x0a, - 0x01, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x01, 0x73, 0x12, 0x25, 0x0a, 0x08, 0x65, - 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x09, 0x2e, - 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x52, 0x08, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, - 0x65, 0x64, 0x32, 0x1f, 0x0a, 0x07, 0x54, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x12, 0x14, 0x0a, - 0x04, 0x63, 0x61, 0x6c, 0x6c, 0x12, 0x05, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x05, 0x2e, 0x54, - 0x65, 0x73, 0x74, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x67, 0x68, 0x6f, 0x73, 0x74, 0x69, 0x61, 0x6d, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x67, 0x65, 0x74, 0x74, 0x65, 0x72, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6c, 0x88, 0x01, 0x01, 0x12, 0x2b, 0x0a, 0x08, 0x6f, 0x70, 0x74, 0x5f, 0x65, 0x6e, 0x75, 0x6d, + 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x2e, 0x4f, 0x45, + 0x6e, 0x75, 0x6d, 0x48, 0x01, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x45, 0x6e, 0x75, 0x6d, 0x88, 0x01, + 0x01, 0x22, 0x21, 0x0a, 0x05, 0x4f, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x0b, 0x0a, 0x07, 0x4f, 0x5f, + 0x45, 0x4e, 0x55, 0x4d, 0x31, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x4f, 0x5f, 0x45, 0x4e, 0x55, + 0x4d, 0x32, 0x10, 0x01, 0x42, 0x0b, 0x0a, 0x09, 0x5f, 0x6f, 0x70, 0x74, 0x5f, 0x62, 0x6f, 0x6f, + 0x6c, 0x42, 0x0b, 0x0a, 0x09, 0x5f, 0x6f, 0x70, 0x74, 0x5f, 0x65, 0x6e, 0x75, 0x6d, 0x22, 0x3f, + 0x0a, 0x08, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x12, 0x0c, 0x0a, 0x01, 0x73, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x01, 0x73, 0x12, 0x25, 0x0a, 0x08, 0x65, 0x6d, 0x62, 0x65, + 0x64, 0x64, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x09, 0x2e, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x65, 0x64, 0x52, 0x08, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x32, + 0x1f, 0x0a, 0x07, 0x54, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x12, 0x14, 0x0a, 0x04, 0x63, 0x61, + 0x6c, 0x6c, 0x12, 0x05, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x05, 0x2e, 0x54, 0x65, 0x73, 0x74, + 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, + 0x68, 0x6f, 0x73, 0x74, 0x69, 0x61, 0x6d, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x67, 0x65, 0x74, + 0x74, 0x65, 0x72, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -257,22 +318,25 @@ func file_test_proto_rawDescGZIP() []byte { return file_test_proto_rawDescData } +var file_test_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_test_proto_goTypes = []interface{}{ - (*Test)(nil), // 0: Test - (*Embedded)(nil), // 1: Embedded + (Test_OEnum)(0), // 0: Test.OEnum + (*Test)(nil), // 1: Test + (*Embedded)(nil), // 2: Embedded } var file_test_proto_depIdxs = []int32{ - 1, // 0: Test.embedded:type_name -> Embedded - 1, // 1: Test.repeated_embeddeds:type_name -> Embedded - 1, // 2: Embedded.embedded:type_name -> Embedded - 0, // 3: Testing.call:input_type -> Test - 0, // 4: Testing.call:output_type -> Test - 4, // [4:5] is the sub-list for method output_type - 3, // [3:4] is the sub-list for method input_type - 3, // [3:3] is the sub-list for extension type_name - 3, // [3:3] is the sub-list for extension extendee - 0, // [0:3] is the sub-list for field type_name + 2, // 0: Test.embedded:type_name -> Embedded + 2, // 1: Test.repeated_embeddeds:type_name -> Embedded + 0, // 2: Test.opt_enum:type_name -> Test.OEnum + 2, // 3: Embedded.embedded:type_name -> Embedded + 1, // 4: Testing.call:input_type -> Test + 1, // 5: Testing.call:output_type -> Test + 5, // [5:6] is the sub-list for method output_type + 4, // [4:5] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_test_proto_init() } @@ -312,13 +376,14 @@ func file_test_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_test_proto_rawDesc, - NumEnums: 0, + NumEnums: 1, NumMessages: 2, NumExtensions: 0, NumServices: 1, }, GoTypes: file_test_proto_goTypes, DependencyIndexes: file_test_proto_depIdxs, + EnumInfos: file_test_proto_enumTypes, MessageInfos: file_test_proto_msgTypes, }.Build() File_test_proto = out.File diff --git a/testdata/proto/test.proto b/testdata/proto/test.proto index 2d38aab..c613b97 100644 --- a/testdata/proto/test.proto +++ b/testdata/proto/test.proto @@ -17,6 +17,13 @@ message Test { // issue #4 optional bool opt_bool = 12; + + // Optional enum + enum OEnum { + O_ENUM1 = 0; + O_ENUM2 = 1; + } + optional OEnum opt_enum = 13; } message Embedded { diff --git a/testdata/test.go b/testdata/test.go index 3ef25bf..4df5698 100644 --- a/testdata/test.go +++ b/testdata/test.go @@ -79,6 +79,17 @@ func testInvalid(t *proto.Test) { t.Embedded.SetS("test") // want `avoid direct access to proto field t\.Embedded\.SetS\("test"\), use t\.GetEmbedded\(\)\.SetS\("test"\) instead` t.Embedded.SetMap(map[string]string{"test": "test"}) // want `avoid direct access to proto field t\.Embedded\.SetMap\(map\[string\]string{"test": "test"}\), use t\.GetEmbedded\(\)\.SetMap\(map\[string\]string{"test": "test"}\) instead` + + // Optional enum + switch *t.OptEnum { // want `avoid direct access to proto field \*t\.OptEnum, use t\.GetOptEnum\(\) instead` + case proto.Test_O_ENUM1: + case proto.Test_O_ENUM2: + } + + if t.OptEnum != nil && *t.OptEnum == proto.Test_O_ENUM1 { // want `avoid direct access to proto field \*t\.OptEnum, use t\.GetOptEnum\(\) instead` + } + + _ = *t.OptEnum // want `avoid direct access to proto field \*t\.OptEnum, use t\.GetOptEnum\(\) instead` } func testValid(t *proto.Test) { diff --git a/testdata/test.go.golden b/testdata/test.go.golden index 134b42e..d98bcc3 100644 --- a/testdata/test.go.golden +++ b/testdata/test.go.golden @@ -79,6 +79,17 @@ func testInvalid(t *proto.Test) { t.GetEmbedded().SetS("test") // want `avoid direct access to proto field t\.Embedded\.SetS\("test"\), use t\.GetEmbedded\(\)\.SetS\("test"\) instead` t.GetEmbedded().SetMap(map[string]string{"test": "test"}) // want `avoid direct access to proto field t\.Embedded\.SetMap\(map\[string\]string{"test": "test"}\), use t\.GetEmbedded\(\)\.SetMap\(map\[string\]string{"test": "test"}\) instead` + + // Optional enum + switch t.GetOptEnum() { // want `avoid direct access to proto field \*t\.OptEnum, use t\.GetOptEnum\(\) instead` + case proto.Test_O_ENUM1: + case proto.Test_O_ENUM2: + } + + if t.OptEnum != nil && t.GetOptEnum() == proto.Test_O_ENUM1 { // want `avoid direct access to proto field \*t\.OptEnum, use t\.GetOptEnum\(\) instead` + } + + _ = t.GetOptEnum() // want `avoid direct access to proto field \*t\.OptEnum, use t\.GetOptEnum\(\) instead` } func testValid(t *proto.Test) {