diff --git a/codegen/config/binder.go b/codegen/config/binder.go index 8e71e0ca7a0..499483229c6 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -36,13 +36,12 @@ func (c *Config) NewBinder() *Binder { } func (b *Binder) TypePosition(typ types.Type) token.Position { - named, isNamed := typ.(*types.Named) + named, isNamed := code.UnwrapAlias(typ).(*types.Named) if !isNamed { return token.Position{ Filename: "unknown", } } - return b.ObjectPosition(named.Obj()) } @@ -264,13 +263,13 @@ func (ref *TypeReference) IsPtrToIntf() bool { } func (ref *TypeReference) IsNamed() bool { - _, isSlice := ref.GO.(*types.Named) - return isSlice + _, ok := ref.GO.(*types.Named) + return ok } func (ref *TypeReference) IsStruct() bool { - _, isStruct := ref.GO.Underlying().(*types.Struct) - return isStruct + _, ok := ref.GO.Underlying().(*types.Struct) + return ok } func (ref *TypeReference) IsScalar() bool { @@ -433,28 +432,28 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret if err != nil { return nil, err } - + t := code.UnwrapAlias(obj.Type()) if values := b.enumValues(def); len(values) > 0 { err = b.enumReference(ref, obj, values) if err != nil { return nil, err } } else if fun, isFunc := obj.(*types.Func); isFunc { - ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() - ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler" + ref.GO = t.(*types.Signature).Params().At(0).Type() + ref.IsContext = t.(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler" ref.Marshaler = fun ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) - } else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") { - ref.GO = obj.Type() + } else if hasMethod(t, "MarshalGQLContext") && hasMethod(t, "UnmarshalGQLContext") { + ref.GO = t ref.IsContext = true ref.IsMarshaler = true - } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { - ref.GO = obj.Type() + } else if hasMethod(t, "MarshalGQL") && hasMethod(t, "UnmarshalGQL") { + ref.GO = t ref.IsMarshaler = true - } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { + } else if underlying := basicUnderlying(t); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { // TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595) - ref.GO = obj.Type() + ref.GO = t ref.CastType = underlying underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) @@ -465,7 +464,7 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret ref.Marshaler = underlyingRef.Marshaler ref.Unmarshaler = underlyingRef.Unmarshaler } else { - ref.GO = obj.Type() + ref.GO = t } ref.Target = ref.GO @@ -587,10 +586,11 @@ func (b *Binder) enumReference(ref *TypeReference, obj types.Object, values map[ return fmt.Errorf("not all enum values are binded for %v", ref.Definition.Name) } - if fn, ok := obj.Type().(*types.Signature); ok { + t := code.UnwrapAlias(obj.Type()) + if fn, ok := t.(*types.Signature); ok { ref.GO = fn.Params().At(0).Type() } else { - ref.GO = obj.Type() + ref.GO = t } str, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) diff --git a/codegen/field.go b/codegen/field.go index 4a492ff9a0c..151a5bcfd9f 100644 --- a/codegen/field.go +++ b/codegen/field.go @@ -6,6 +6,7 @@ import ( goast "go/ast" "go/types" "log" + "reflect" "strconv" "strings" @@ -15,6 +16,7 @@ import ( "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/templates" + "github.com/99designs/gqlgen/internal/code" ) type Field struct { @@ -235,7 +237,7 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) { func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) { // NOTE: a struct tag will override both methods and fields // Bind to struct tag - found, err := b.findBindStructTagTarget(t, name) + found, err := b.findBindStructTagTarget(code.UnwrapAlias(t), name) if found != nil || err != nil { return found, err } @@ -247,7 +249,7 @@ func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error } // Search for a field to bind to - foundField, err := b.findBindFieldTarget(t, name) + foundField, err := b.findBindFieldTarget(code.UnwrapAlias(t), name) if err != nil { return nil, err } @@ -265,7 +267,7 @@ func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error } // Search embeds - return b.findBindEmbedsTarget(t, name) + return b.findBindEmbedsTarget(code.UnwrapAlias(t), name) } func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) { @@ -338,6 +340,91 @@ func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name str return found, nil } +func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { + switch t := in.(type) { + case *types.Named: + return b.findBindFieldTarget(t.Underlying(), name) + case *types.Struct: + var found types.Object + for i := 0; i < t.NumFields(); i++ { + field := t.Field(i) + if !field.Exported() || !equalFieldName(field.Name(), name) { + continue + } + + if found != nil { + return nil, fmt.Errorf("found more than one matching field to bind for %s", name) + } + + found = field + } + + return found, nil + } + + return nil, nil +} + +func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { + switch t := in.(type) { + case *types.Named: + return b.findBindEmbedsTarget(t.Underlying(), name) + case *types.Struct: + return b.findBindStructEmbedsTarget(t, name) + case *types.Interface: + return b.findBindInterfaceEmbedsTarget(t, name) + } + + return nil, nil +} + +func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { + if b.Config.StructTag == "" { + return nil, nil + } + + switch t := in.(type) { + case *types.Named: + return b.findBindStructTagTarget(t.Underlying(), name) + case *types.Struct: + var found types.Object + for i := 0; i < t.NumFields(); i++ { + field := t.Field(i) + if !field.Exported() || field.Embedded() { + continue + } + tags := reflect.StructTag(t.Tag(i)) + if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { + if found != nil { + return nil, fmt.Errorf("tag %s is ambiguous; multiple fields have the same tag value of %s", b.Config.StructTag, val) + } + + found = field + } + } + + return found, nil + } + + return nil, nil +} + +func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { + switch t := in.(type) { + case *types.Named: + if _, ok := t.Underlying().(*types.Interface); ok { + return b.findBindMethodTarget(t.Underlying(), name) + } + + return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) + case *types.Interface: + // FIX-ME: Should use ExplicitMethod here? What's the difference? + return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) + } + + return nil, nil +} + func (f *Field) HasDirectives() bool { return len(f.ImplDirectives()) > 0 } diff --git a/codegen/field_1.23.go b/codegen/field_1.23.go deleted file mode 100644 index 105940f2cf9..00000000000 --- a/codegen/field_1.23.go +++ /dev/null @@ -1,102 +0,0 @@ -//go:build go1.23 - -package codegen - -import ( - "fmt" - "go/types" - "reflect" -) - -func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { - switch t := in.(type) { - case *types.Alias: - return b.findBindFieldTarget(t.Rhs(), name) - case *types.Named: - return b.findBindFieldTarget(t.Underlying(), name) - case *types.Struct: - var found types.Object - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - if !field.Exported() || !equalFieldName(field.Name(), name) { - continue - } - - if found != nil { - return nil, fmt.Errorf("found more than one matching field to bind for %s", name) - } - - found = field - } - - return found, nil - } - - return nil, nil -} - -func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { - switch t := in.(type) { - case *types.Alias: - return b.findBindEmbedsTarget(t.Rhs(), name) - case *types.Named: - return b.findBindEmbedsTarget(t.Underlying(), name) - case *types.Struct: - return b.findBindStructEmbedsTarget(t, name) - case *types.Interface: - return b.findBindInterfaceEmbedsTarget(t, name) - } - - return nil, nil -} - -func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { - if b.Config.StructTag == "" { - return nil, nil - } - - switch t := in.(type) { - case *types.Alias: - return b.findBindStructTagTarget(t.Rhs(), name) - case *types.Named: - return b.findBindStructTagTarget(t.Underlying(), name) - case *types.Struct: - var found types.Object - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - if !field.Exported() || field.Embedded() { - continue - } - tags := reflect.StructTag(t.Tag(i)) - if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { - if found != nil { - return nil, fmt.Errorf("tag %s is ambiguous; multiple fields have the same tag value of %s", b.Config.StructTag, val) - } - - found = field - } - } - - return found, nil - } - - return nil, nil -} - -func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { - switch t := in.(type) { - case *types.Alias: - return b.findBindMethodTarget(t.Rhs(), name) - case *types.Named: - if _, ok := t.Underlying().(*types.Interface); ok { - return b.findBindMethodTarget(t.Underlying(), name) - } - - return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) - case *types.Interface: - // FIX-ME: Should use ExplicitMethod here? What's the difference? - return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) - } - - return nil, nil -} diff --git a/codegen/field_other.go b/codegen/field_other.go deleted file mode 100644 index 2e5e0d51fb4..00000000000 --- a/codegen/field_other.go +++ /dev/null @@ -1,94 +0,0 @@ -//go:build !go1.23 - -package codegen - -import ( - "fmt" - "go/types" - "reflect" -) - -func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) { - switch t := in.(type) { - case *types.Named: - return b.findBindFieldTarget(t.Underlying(), name) - case *types.Struct: - var found types.Object - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - if !field.Exported() || !equalFieldName(field.Name(), name) { - continue - } - - if found != nil { - return nil, fmt.Errorf("found more than one matching field to bind for %s", name) - } - - found = field - } - - return found, nil - } - - return nil, nil -} - -func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) { - switch t := in.(type) { - case *types.Named: - return b.findBindEmbedsTarget(t.Underlying(), name) - case *types.Struct: - return b.findBindStructEmbedsTarget(t, name) - case *types.Interface: - return b.findBindInterfaceEmbedsTarget(t, name) - } - - return nil, nil -} - -func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) { - if b.Config.StructTag == "" { - return nil, nil - } - - switch t := in.(type) { - case *types.Named: - return b.findBindStructTagTarget(t.Underlying(), name) - case *types.Struct: - var found types.Object - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - if !field.Exported() || field.Embedded() { - continue - } - tags := reflect.StructTag(t.Tag(i)) - if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { - if found != nil { - return nil, fmt.Errorf("tag %s is ambiguous; multiple fields have the same tag value of %s", b.Config.StructTag, val) - } - - found = field - } - } - - return found, nil - } - - return nil, nil -} - -func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) { - switch t := in.(type) { - case *types.Named: - if _, ok := t.Underlying().(*types.Interface); ok { - return b.findBindMethodTarget(t.Underlying(), name) - } - - return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) - case *types.Interface: - // FIX-ME: Should use ExplicitMethod here? What's the difference? - return b.findBindMethoderTarget(t.Method, t.NumMethods(), name) - } - - return nil, nil -} diff --git a/codegen/object.go b/codegen/object.go index 1b780bd0c1d..99d19bda67c 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -12,6 +12,7 @@ import ( "golang.org/x/text/language" "github.com/99designs/gqlgen/codegen/config" + "github.com/99designs/gqlgen/internal/code" ) type GoFieldType int @@ -116,8 +117,9 @@ func (o *Object) HasUnmarshal() bool { if o.IsMap() { return false } - for i := 0; i < o.Type.(*types.Named).NumMethods(); i++ { - if o.Type.(*types.Named).Method(i).Name() == "UnmarshalGQL" { + t := code.UnwrapAlias(o.Type) + for i := 0; i < t.(*types.Named).NumMethods(); i++ { + if t.(*types.Named).Method(i).Name() == "UnmarshalGQL" { return true } } diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index c89d912ff3d..b8b524138c4 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -275,6 +275,38 @@ func Call(p *types.Func) string { return pkg + p.Name() } +func TypeIdentifier(t types.Type) string { + t = code.UnwrapAlias(t) + res := "" + for { + switch it := t.(type) { + case *types.Pointer: + t.Underlying() + res += "ᚖ" + t = it.Elem() + case *types.Slice: + res += "ᚕ" + t = it.Elem() + case *types.Named: + res += pkgReplacer.Replace(it.Obj().Pkg().Path()) + res += "ᚐ" + res += it.Obj().Name() + return res + case *types.Basic: + res += it.Name() + return res + case *types.Map: + res += "map" + return res + case *types.Interface: + res += "interface" + return res + default: + panic(fmt.Errorf("unexpected type %T", it)) + } + } +} + func resetModelNames() { modelNamesMu.Lock() defer modelNamesMu.Unlock() diff --git a/codegen/templates/templates_1.23.go b/codegen/templates/templates_1.23.go deleted file mode 100644 index 3fd00a1ed68..00000000000 --- a/codegen/templates/templates_1.23.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build go1.23 - -package templates - -import ( - "fmt" - "go/types" -) - -func TypeIdentifier(t types.Type) string { - res := "" - for { - switch it := t.(type) { - case *types.Pointer: - t.Underlying() - res += "ᚖ" - t = it.Elem() - case *types.Slice: - res += "ᚕ" - t = it.Elem() - case *types.Named: - res += pkgReplacer.Replace(it.Obj().Pkg().Path()) - res += "ᚐ" - res += it.Obj().Name() - return res - case *types.Basic: - res += it.Name() - return res - case *types.Map: - res += "map" - return res - case *types.Interface: - res += "interface" - return res - case *types.Alias: - return TypeIdentifier(it.Rhs()) - default: - panic(fmt.Errorf("unexpected type %T", it)) - } - } -} diff --git a/codegen/templates/templates_other.go b/codegen/templates/templates_other.go deleted file mode 100644 index 3e4752931d8..00000000000 --- a/codegen/templates/templates_other.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build !go1.23 - -package templates - -import ( - "fmt" - "go/types" -) - -func TypeIdentifier(t types.Type) string { - res := "" - for { - switch it := t.(type) { - case *types.Pointer: - t.Underlying() - res += "ᚖ" - t = it.Elem() - case *types.Slice: - res += "ᚕ" - t = it.Elem() - case *types.Named: - res += pkgReplacer.Replace(it.Obj().Pkg().Path()) - res += "ᚐ" - res += it.Obj().Name() - return res - case *types.Basic: - res += it.Name() - return res - case *types.Map: - res += "map" - return res - case *types.Interface: - res += "interface" - return res - default: - panic(fmt.Errorf("unexpected type %T", it)) - } - } -} diff --git a/internal/code/alias.go b/internal/code/alias.go new file mode 100644 index 00000000000..cb1b1fd5ce7 --- /dev/null +++ b/internal/code/alias.go @@ -0,0 +1,12 @@ +//go:build !go1.23 + +package code + +import ( + "go/types" +) + +// UnwrapAlias unwraps an alias type +func UnwrapAlias(t types.Type) types.Type { + return t // No-op +} diff --git a/internal/code/alias_1.23.go b/internal/code/alias_1.23.go new file mode 100644 index 00000000000..23f3a71cfd7 --- /dev/null +++ b/internal/code/alias_1.23.go @@ -0,0 +1,15 @@ +//go:build go1.23 + +package code + +import ( + "go/types" +) + +// UnwrapAlias unwraps an alias type +func UnwrapAlias(t types.Type) types.Type { + if a, ok := t.(*types.Alias); ok { + return a.Rhs() + } + return t +}