From 395861944411ba1aee74ca90a0f3943dc0a512ae Mon Sep 17 00:00:00 2001 From: Fabian Simon Date: Thu, 9 Mar 2023 09:51:46 +0100 Subject: [PATCH] fix: complex filter queries wich show more than one time to same table only one join will executed --- generate_code_filter_model.go.tpl | 58 +++++++++++++++++-------------- generate_code_resolver.go.tpl | 12 ++++--- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/generate_code_filter_model.go.tpl b/generate_code_filter_model.go.tpl index d8ebf00..04375db 100644 --- a/generate_code_filter_model.go.tpl +++ b/generate_code_filter_model.go.tpl @@ -27,12 +27,12 @@ func (d *{{$object.Name}}FiltersInput) PrimaryKeyName() string { } -func (d *{{$object.Name}}FiltersInput) {{$methodeName}}(db *gorm.DB, alias string,deep bool) []runtimehelper.ConditionElement { +func (d *{{$object.Name}}FiltersInput) {{$methodeName}}(db *gorm.DB, alias string,deep bool, blackList map[string]struct{}) []runtimehelper.ConditionElement { res := make([]runtimehelper.ConditionElement, 0) if d.And != nil { tmp := make([]runtimehelper.ConditionElement, 0) for _, v := range d.And { - tmp = append(tmp, runtimehelper.Complex(runtimehelper.RelationAnd,v.ExtendsDatabaseQuery(db, alias, true)...)) + tmp = append(tmp, runtimehelper.Complex(runtimehelper.RelationAnd,v.ExtendsDatabaseQuery(db, alias, true,blackList)...)) } res = append(res, runtimehelper.Complex(runtimehelper.RelationAnd,tmp...)) } @@ -41,13 +41,13 @@ func (d *{{$object.Name}}FiltersInput) {{$methodeName}}(db *gorm.DB, alias strin tmp := make([]runtimehelper.ConditionElement, 0) for _, v := range d.Or { - tmp = append(tmp, runtimehelper.Complex(runtimehelper.RelationAnd, v.ExtendsDatabaseQuery(db, alias, true)...)) + tmp = append(tmp, runtimehelper.Complex(runtimehelper.RelationAnd, v.ExtendsDatabaseQuery(db, alias, true,blackList)...)) } res = append(res, runtimehelper.Complex(runtimehelper.RelationOr,tmp...)) } if d.Not != nil { - res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db, alias, true)...)) + res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db, alias, true,blackList)...)) } {{- range $entityKey, $entity := $object.InputFilterEntities }} {{- $entityGoName := $root.GetGoFieldName $objectName $entity}} @@ -55,24 +55,30 @@ func (d *{{$object.Name}}FiltersInput) {{$methodeName}}(db *gorm.DB, alias strin {{- if or $entity.IsPrimitive $entity.GqlTypeObj.HasSqlDirective }} if d.{{$entityGoName}} != nil { {{- if $entity.IsPrimitive }} - res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, fmt.Sprintf("%s.%s",alias,"{{snakecase $entityGoName}}"),true)...) + res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, fmt.Sprintf("%s.%s",alias,"{{snakecase $entityGoName}}"),true,blackList)...) {{- else }} {{- if $entity.HasMany2ManyDirective}} tableName := db.Config.NamingStrategy.TableName("{{$root.GetGoFieldTypeName $objectName $entity }}") - {{- $m2mTableName := $entity.Many2ManyDirectiveTable}} - db := db.Joins(fmt.Sprintf("JOIN {{$m2mTableName}} ON {{$m2mTableName}}.{{$object.Name | snakecase}}_{{$root.PrimaryKeyOfObject $object.Name}} = %s.{{$root.PrimaryKeyOfObject $object.Name}} JOIN %s ON {{$m2mTableName}}.{{$entity.GqlTypeName | snakecase}}_{{$root.PrimaryKeyOfObject $entity.GqlTypeName | snakecase}} = %s.{{$root.PrimaryKeyOfObject $object.Name}}", alias, tableName,tableName)) - res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, tableName,true)...) + {{- $m2mTableName := $entity.Many2ManyDirectiveTable}} + if _, ok := blackList["{{$m2mTableName}}"]; !ok { + blackList["{{$m2mTableName}}"] = struct{}{} + db = db.Joins(fmt.Sprintf("JOIN {{$m2mTableName}} ON {{$m2mTableName}}.{{$object.Name | snakecase}}_{{$root.PrimaryKeyOfObject $object.Name}} = %s.{{$root.PrimaryKeyOfObject $object.Name}} JOIN %s ON {{$m2mTableName}}.{{$entity.GqlTypeName | snakecase}}_{{$root.PrimaryKeyOfObject $entity.GqlTypeName | snakecase}} = %s.{{$root.PrimaryKeyOfObject $object.Name}}", alias, tableName,tableName)) + } + res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, tableName,true,blackList)...) {{- else if eq $object.Name $entity.GqlTypeName}} - res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, "{{$entityGoName}}",true)...) + res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, "{{$entityGoName}}",true,blackList)...) {{- else }} - if deep { - tableName := db.Config.NamingStrategy.TableName("{{$root.GetGoFieldTypeName $objectName $entity }}") - foreignKeyName := "{{$root.ForeignName $object $entity | snakecase}}" - db = db.Joins(fmt.Sprintf("JOIN %s {{$entityGoName}} ON {{$entityGoName}}.%s = %s.%s",tableName, foreignKeyName, alias, d.PrimaryKeyName())) - }else { - db = db.Joins("{{$entityGoName}}") + if _, ok := blackList["{{$entityGoName}}"]; !ok { + blackList["{{$entityGoName}}"] = struct{}{} + if deep { + tableName := db.Config.NamingStrategy.TableName("{{$root.GetGoFieldTypeName $objectName $entity }}") + foreignKeyName := "{{$root.ForeignName $object $entity | snakecase}}" + db = db.Joins(fmt.Sprintf("JOIN %s {{$entityGoName}} ON {{$entityGoName}}.%s = %s.%s",tableName, foreignKeyName, alias, d.PrimaryKeyName())) + }else { + db = db.Joins("{{$entityGoName}}") + } } - res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, "{{$entityGoName}}",true)...) + res = append(res, d.{{$entityGoName}}.{{$methodeName}}(db, "{{$entityGoName}}",true,blackList)...) {{- end}} {{- end}} } @@ -84,7 +90,7 @@ func (d *{{$object.Name}}FiltersInput) {{$methodeName}}(db *gorm.DB, alias strin {{- end}} {{- end}} -func (d *StringFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool) []runtimehelper.ConditionElement { +func (d *StringFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool, blackList map[string]struct{}) []runtimehelper.ConditionElement { res := make([]runtimehelper.ConditionElement, 0) if d.And != nil { tmp := make([]runtimehelper.ConditionElement, 0) @@ -122,7 +128,7 @@ func (d *StringFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, } if d.Not != nil { - res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db,fieldName, true)...)) + res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db,fieldName, true, blackList)...)) } if d.NotContains != nil { @@ -160,7 +166,7 @@ func (d *StringFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, return res } -func (d *IntFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool) []runtimehelper.ConditionElement { +func (d *IntFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool, blackList map[string]struct{}) []runtimehelper.ConditionElement { res := make([]runtimehelper.ConditionElement, 0) @@ -203,7 +209,7 @@ func (d *IntFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, dee res = append(res, runtimehelper.NotEqual(fieldName,*d.Ne)) } if d.Not != nil { - res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db,fieldName, true)...)) + res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db,fieldName, true,blackList)...)) } if d.NotIn != nil { @@ -230,7 +236,7 @@ func (d *IntFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, dee return res } -func (d *BooleanFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool) []runtimehelper.ConditionElement { +func (d *BooleanFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool, blackList map[string]struct{}) []runtimehelper.ConditionElement { res := make([]runtimehelper.ConditionElement, 0) if d.And != nil { @@ -246,7 +252,7 @@ func (d *BooleanFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, } if d.Not != nil { - res = append(res, runtimehelper.Complex(runtimehelper.RelationNot, d.Not.ExtendsDatabaseQuery(db, fieldName, true)...)) + res = append(res, runtimehelper.Complex(runtimehelper.RelationNot, d.Not.ExtendsDatabaseQuery(db, fieldName, true,blackList)...)) } if d.NotNull != nil && *d.NotNull { @@ -268,7 +274,7 @@ func (d *BooleanFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, return res } -func (d *TimeFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool) []runtimehelper.ConditionElement { +func (d *TimeFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool, blackList map[string]struct{}) []runtimehelper.ConditionElement { res := make([]runtimehelper.ConditionElement, 0) @@ -311,7 +317,7 @@ func (d *TimeFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, de res = append(res, runtimehelper.NotEqual(fieldName,*d.Ne)) } if d.Not != nil { - res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db,fieldName, true)...)) + res = append(res, runtimehelper.Complex(runtimehelper.RelationNot,d.Not.ExtendsDatabaseQuery(db,fieldName, true,blackList)...)) } if d.NotIn != nil { @@ -338,7 +344,7 @@ func (d *TimeFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, de return res } -func (d *IDFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool) []runtimehelper.ConditionElement { +func (d *IDFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep bool, blackList map[string]struct{}) []runtimehelper.ConditionElement { res := make([]runtimehelper.ConditionElement, 0) @@ -364,7 +370,7 @@ func (d *IDFilterInput) ExtendsDatabaseQuery(db *gorm.DB, fieldName string, deep res = append(res, runtimehelper.NotEqual(fieldName, *d.Ne)) } if d.Not != nil { - res = append(res, runtimehelper.Complex(runtimehelper.RelationNot, d.Not.ExtendsDatabaseQuery(db, fieldName, true)...)) + res = append(res, runtimehelper.Complex(runtimehelper.RelationNot, d.Not.ExtendsDatabaseQuery(db, fieldName, true,blackList)...)) } diff --git a/generate_code_resolver.go.tpl b/generate_code_resolver.go.tpl index 0293ade..6577cc5 100644 --- a/generate_code_resolver.go.tpl +++ b/generate_code_resolver.go.tpl @@ -66,7 +66,8 @@ func (r *queryResolver) Query{{$object.Name}}(ctx context.Context, filter *model tableName := r.Sql.Db.Config.NamingStrategy.TableName("{{$object.Name}}") db = runtimehelper.GetPreloadSelection(ctx, db,runtimehelper.GetPreloadsMap(ctx, "data").SubTables[0]) if filter != nil{ - sql, arguments := runtimehelper.CombineSimpleQuery(filter.ExtendsDatabaseQuery(db, tableName, false), "AND") + blackList := make(map[string]struct{}) + sql, arguments := runtimehelper.CombineSimpleQuery(filter.ExtendsDatabaseQuery(db, tableName, false, blackList), "AND") db.Where(sql, arguments...) } @@ -133,7 +134,8 @@ func (r *{{lcFirst $object.Name}}PayloadResolver[T]) {{$object.Name}}(ctx contex {{- range $m2mKey, $m2mEntity := $object.Many2ManyRefEntities }} func (r *mutationResolver) Add{{$m2mEntity.GqlTypeName}}2{{$object.Name}}s(ctx context.Context, input model.{{$m2mEntity.GqlTypeName}}Ref2{{$object.Name}}sInput) (*model.Update{{$object.Name}}Payload, error){ tableName := r.Sql.Db.Config.NamingStrategy.TableName("{{$object.Name}}") - sql, arguments := runtimehelper.CombineSimpleQuery(input.Filter.ExtendsDatabaseQuery(r.Sql.Db, tableName, false), "AND") + blackList := make(map[string]struct{}) + sql, arguments := runtimehelper.CombineSimpleQuery(input.Filter.ExtendsDatabaseQuery(r.Sql.Db, tableName, false, blackList), "AND") db := r.Sql.Db.Model(&model.{{$object.Name}}{}).Where(sql, arguments...) var res []*model.{{$object.Name}} db.Find(&res) @@ -208,7 +210,8 @@ func (r *mutationResolver) Update{{$object.Name}}(ctx context.Context, input mod } } tableName := r.Sql.Db.Config.NamingStrategy.TableName("{{$object.Name}}") - sql, arguments := runtimehelper.CombineSimpleQuery(input.Filter.ExtendsDatabaseQuery(r.Sql.Db, tableName, false), "AND") + blackList := make(map[string]struct{}) + sql, arguments := runtimehelper.CombineSimpleQuery(input.Filter.ExtendsDatabaseQuery(r.Sql.Db, tableName, false, blackList), "AND") obj := model.{{$object.Name}}{} db = db.Model(&obj).Where(sql, arguments...) update := input.Set.MergeToType() @@ -246,7 +249,8 @@ func (r *mutationResolver) Delete{{$object.Name}}(ctx context.Context, filter mo } } tableName := r.Sql.Db.Config.NamingStrategy.TableName("{{$object.Name}}") - sql, arguments := runtimehelper.CombineSimpleQuery(filter.ExtendsDatabaseQuery(db, tableName, false), "AND") + blackList := make(map[string]struct{}) + sql, arguments := runtimehelper.CombineSimpleQuery(filter.ExtendsDatabaseQuery(db, tableName, false, blackList), "AND") obj := model.{{$object.Name}}{} db = db.Where(sql, arguments...) if okHook {