Skip to content

Commit

Permalink
refactor: don't extract @gofield twice
Browse files Browse the repository at this point in the history
We already extract the values in config.Init(). Remove the duplicate logic in the modelgen plugin.

We leave the reference to GoFieldHook even though it's a noop since it's public. This makes this a non-breaking change. We will remove this during the next breaking release.
  • Loading branch information
clayne11 committed Aug 27, 2024
1 parent 3e76e7e commit 6f99792
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 118 deletions.
17 changes: 14 additions & 3 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,24 +320,33 @@ func (c *Config) injectTypesFromSchema() error {
}
}

if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
if schemaType.Kind == ast.Object ||
schemaType.Kind == ast.InputObject ||
schemaType.Kind == ast.Interface {
for _, field := range schemaType.Fields {
if fd := field.Directives.ForName("goField"); fd != nil {
forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName

if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
if fr, err := ra.Value.Value(nil); err == nil {
forceResolver = fr.(bool)
}
}

fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
if na := fd.Arguments.ForName("name"); na != nil {
if fr, err := na.Value.Value(nil); err == nil {
fieldName = fr.(string)
}
}

omittable := c.Models[schemaType.Name].Fields[field.Name].Omittable
if arg := fd.Arguments.ForName("omittable"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
val := k.(bool)
omittable = &val
}
}

if c.Models[schemaType.Name].Fields == nil {
c.Models[schemaType.Name] = TypeMapEntry{
Model: c.Models[schemaType.Name].Model,
Expand All @@ -349,6 +358,7 @@ func (c *Config) injectTypesFromSchema() error {
c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
FieldName: fieldName,
Resolver: forceResolver,
Omittable: omittable,
}
}
}
Expand Down Expand Up @@ -449,6 +459,7 @@ type TypeMapEntry struct {
type TypeMapField struct {
Resolver bool `yaml:"resolver"`
FieldName string `yaml:"fieldName"`
Omittable *bool `yaml:"omittable"`
GeneratedMethod string `yaml:"-"`
}

Expand Down
227 changes: 112 additions & 115 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ type (

// DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook.
func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
var err error
f, err = GoFieldHook(td, fd, f)
if err != nil {
return f, err
}
return GoTagFieldHook(td, fd, f)
}

Expand Down Expand Up @@ -337,117 +332,139 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
binder := cfg.NewBinder()
fields := make([]*Field, 0)

var omittableType types.Type

for _, field := range schemaType.Fields {
var typ types.Type
fieldDef := cfg.Schema.Types[field.Type.Name()]

if cfg.Models.UserDefined(field.Type.Name()) {
var err error
typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
if err != nil {
return nil, err
}
} else {
switch fieldDef.Kind {
case ast.Scalar:
// no user defined model, referencing a default scalar
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
nil,
nil,
)

case ast.Interface, ast.Union:
// no user defined model, referencing a generated interface type
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewInterfaceType([]*types.Func{}, []types.Type{}),
nil,
)

case ast.Enum:
// no user defined model, must reference a generated enum
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
nil,
nil,
)

case ast.Object, ast.InputObject:
// no user defined model, must reference a generated struct
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewStruct(nil, nil),
nil,
)

default:
panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
}
f, err := m.generateField(cfg, binder, schemaType, field)
if err != nil {
return nil, err
}

name := templates.ToGo(field.Name)
if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" {
name = nameOverride
if f == nil {
continue
}

typ = binder.CopyModifiersFromAst(field.Type, typ)
fields = append(fields, f)
}

if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
fields = append(fields, getExtraFields(cfg, schemaType.Name)...)

return fields, nil
}

func (m *Plugin) generateField(
cfg *config.Config,
binder *config.Binder,
schemaType *ast.Definition,
field *ast.FieldDefinition,
) (*Field, error) {
var omittableType types.Type
var typ types.Type
fieldDef := cfg.Schema.Types[field.Type.Name()]

if cfg.Models.UserDefined(field.Type.Name()) {
var err error
typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
if err != nil {
return nil, err
}
} else {
switch fieldDef.Kind {
case ast.Scalar:
// no user defined model, referencing a default scalar
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
nil,
nil,
)

case ast.Interface, ast.Union:
// no user defined model, referencing a generated interface type
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewInterfaceType([]*types.Func{}, []types.Type{}),
nil,
)

case ast.Enum:
// no user defined model, must reference a generated enum
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
nil,
nil,
)

f := &Field{
Name: field.Name,
GoName: name,
Type: typ,
Description: field.Description,
Tag: getStructTagFromField(cfg, field),
Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull,
case ast.Object, ast.InputObject:
// no user defined model, must reference a generated struct
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewStruct(nil, nil),
nil,
)

default:
panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
}
}

if m.FieldHook != nil {
mf, err := m.FieldHook(schemaType, field, f)
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
f = mf
name := templates.ToGo(field.Name)
if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" {
name = nameOverride
}

typ = binder.CopyModifiersFromAst(field.Type, typ)

if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
}

if f.IsResolver && cfg.OmitResolverFields {
continue
f := &Field{
Name: field.Name,
GoName: name,
Type: typ,
Description: field.Description,
Tag: getStructTagFromField(cfg, field),
Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull,
IsResolver: cfg.Models[schemaType.Name].Fields[field.Name].Resolver,
}

if omittable := cfg.Models[schemaType.Name].Fields[field.Name].Omittable; omittable != nil {
f.Omittable = *omittable
}

if m.FieldHook != nil {
mf, err := m.FieldHook(schemaType, field, f)
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
f = mf
}

if f.Omittable {
if schemaType.Kind != ast.InputObject || field.Type.NonNull {
return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name)
}
if f.IsResolver && cfg.OmitResolverFields {
return nil, nil
}

var err error
if f.Omittable {
if schemaType.Kind != ast.InputObject || field.Type.NonNull {
return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name)
}

if omittableType == nil {
omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable")
if err != nil {
return nil, err
}
}
var err error

f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type})
if omittableType == nil {
omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable")
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
return nil, err
}
}

fields = append(fields, f)
f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type})
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
}

fields = append(fields, getExtraFields(cfg, schemaType.Name)...)

return fields, nil
return f, nil
}

func getExtraFields(cfg *config.Config, modelName string) []*Field {
Expand Down Expand Up @@ -636,29 +653,9 @@ func removeDuplicateTags(t string) string {
return returnTags
}

// GoFieldHook applies the goField directive to the generated Field f.
// GoFieldHook is a noop
// TODO: This will be removed in the next breaking release
func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
args := make([]string, 0)
_ = args
for _, goField := range fd.Directives.ForNames("goField") {
if arg := goField.Arguments.ForName("name"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.GoName = k.(string)
}
}

if arg := goField.Arguments.ForName("forceResolver"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.IsResolver = k.(bool)
}
}

if arg := goField.Arguments.ForName("omittable"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.Omittable = k.(bool)
}
}
}
return f, nil
}

Expand Down

0 comments on commit 6f99792

Please sign in to comment.