Skip to content

Commit

Permalink
codegen: Go 1.23 alias support
Browse files Browse the repository at this point in the history
  • Loading branch information
giautm committed Sep 5, 2024
1 parent 814f7c7 commit c3ba71c
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 301 deletions.
51 changes: 27 additions & 24 deletions codegen/config/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@ func (c *Config) NewBinder() *Binder {
}

func (b *Binder) TypePosition(typ types.Type) token.Position {
named, isNamed := typ.(*types.Named)
named, isNamed := code.UnwrapAlias(typ).(*types.Named)
if !isNamed {
return token.Position{
Filename: "unknown",
}
}

return b.ObjectPosition(named.Obj())
}

Expand Down Expand Up @@ -119,8 +118,7 @@ func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
if err != nil {
return nil, err
}

return obj.Type(), nil
return code.UnwrapAlias(obj.Type()), nil
}

func (b *Binder) FindObject(pkgName, typeName string) (types.Object, error) {
Expand Down Expand Up @@ -264,13 +262,13 @@ func (ref *TypeReference) IsPtrToIntf() bool {
}

func (ref *TypeReference) IsNamed() bool {
_, isSlice := ref.GO.(*types.Named)
return isSlice
_, ok := ref.GO.(*types.Named)
return ok
}

func (ref *TypeReference) IsStruct() bool {
_, isStruct := ref.GO.Underlying().(*types.Struct)
return isStruct
_, ok := ref.GO.Underlying().(*types.Struct)
return ok
}

func (ref *TypeReference) IsScalar() bool {
Expand Down Expand Up @@ -362,6 +360,9 @@ func unwrapOmittable(t types.Type) (types.Type, bool) {
}

func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
if bindTarget != nil {
bindTarget = code.UnwrapAlias(bindTarget)
}
if innerType, ok := unwrapOmittable(bindTarget); ok {
if schemaType.NonNull {
return nil, fmt.Errorf("%s is wrapped with Omittable but non-null", schemaType.Name())
Expand All @@ -373,7 +374,7 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret
}

ref.IsOmittable = true
return ref, err
return ref, nil
}

if !isValid(bindTarget) {
Expand Down Expand Up @@ -433,28 +434,28 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret
if err != nil {
return nil, err
}

t := code.UnwrapAlias(obj.Type())
if values := b.enumValues(def); len(values) > 0 {
err = b.enumReference(ref, obj, values)
if err != nil {
return nil, err
}
} else if fun, isFunc := obj.(*types.Func); isFunc {
ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler"
ref.GO = t.(*types.Signature).Params().At(0).Type()
ref.IsContext = t.(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler"
ref.Marshaler = fun
ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
ref.GO = obj.Type()
} else if hasMethod(t, "MarshalGQLContext") && hasMethod(t, "UnmarshalGQLContext") {
ref.GO = t
ref.IsContext = true
ref.IsMarshaler = true
} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
ref.GO = obj.Type()
} else if hasMethod(t, "MarshalGQL") && hasMethod(t, "UnmarshalGQL") {
ref.GO = t
ref.IsMarshaler = true
} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
} else if underlying := basicUnderlying(t); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)

ref.GO = obj.Type()
ref.GO = t
ref.CastType = underlying

underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
Expand All @@ -465,7 +466,7 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret
ref.Marshaler = underlyingRef.Marshaler
ref.Unmarshaler = underlyingRef.Unmarshaler
} else {
ref.GO = obj.Type()
ref.GO = t
}

ref.Target = ref.GO
Expand Down Expand Up @@ -587,10 +588,11 @@ func (b *Binder) enumReference(ref *TypeReference, obj types.Object, values map[
return fmt.Errorf("not all enum values are binded for %v", ref.Definition.Name)
}

if fn, ok := obj.Type().(*types.Signature); ok {
ref.GO = fn.Params().At(0).Type()
t := code.UnwrapAlias(obj.Type())
if fn, ok := t.(*types.Signature); ok {
ref.GO = code.UnwrapAlias(fn.Params().At(0).Type())
} else {
ref.GO = obj.Type()
ref.GO = t
}

str, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
Expand Down Expand Up @@ -618,9 +620,10 @@ func (b *Binder) enumReference(ref *TypeReference, obj types.Object, values map[
return err
}

if !types.AssignableTo(valueObj.Type(), ref.GO) {
valueTyp := code.UnwrapAlias(valueObj.Type())
if !types.AssignableTo(valueTyp, ref.GO) {
return fmt.Errorf("wrong type: %v, for enum value: %v, expected type: %v, of enum: %v",
valueObj.Type(), value.Name, ref.GO, ref.Definition.Name)
valueTyp, value.Name, ref.GO, ref.Definition.Name)
}

switch valueObj.(type) {
Expand Down
86 changes: 86 additions & 0 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
goast "go/ast"
"go/types"
"log"
"reflect"
"strconv"
"strings"

Expand Down Expand Up @@ -338,6 +339,91 @@ func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name str
return found, nil
}

func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindFieldTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || !equalFieldName(field.Name(), name) {
continue
}

if found != nil {
return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
}

found = field
}

return found, nil
}

return nil, nil
}

func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindEmbedsTarget(t.Underlying(), name)
case *types.Struct:
return b.findBindStructEmbedsTarget(t, name)
case *types.Interface:
return b.findBindInterfaceEmbedsTarget(t, name)
}

return nil, nil
}

func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
if b.Config.StructTag == "" {
return nil, nil
}

switch t := in.(type) {
case *types.Named:
return b.findBindStructTagTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || field.Embedded() {
continue
}
tags := reflect.StructTag(t.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, fmt.Errorf("tag %s is ambiguous; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

found = field
}
}

return found, nil
}

return nil, nil
}

func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
if _, ok := t.Underlying().(*types.Interface); ok {
return b.findBindMethodTarget(t.Underlying(), name)
}

return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
case *types.Interface:
// FIX-ME: Should use ExplicitMethod here? What's the difference?
return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
}

return nil, nil
}

func (f *Field) HasDirectives() bool {
return len(f.ImplDirectives()) > 0
}
Expand Down
102 changes: 0 additions & 102 deletions codegen/field_1.23.go

This file was deleted.

Loading

0 comments on commit c3ba71c

Please sign in to comment.