From f1bacd72c2a40849cad2c6307612db7b69a54d8f Mon Sep 17 00:00:00 2001 From: miyamo2 <79917704+miyamo2@users.noreply.github.com> Date: Sun, 23 Jun 2024 00:54:52 +0900 Subject: [PATCH] refactor: merge some `.go` files (#131) --- callbacks_registerer.go | 12 - db_opener.go | 22 -- dynmgrm.go | 38 +++ dynmgrm_test.go | 48 ++++ error_translator.go | 22 -- error_translator_test.go | 56 ----- errors.go | 9 - list_type.go | 113 --------- map_type.go | 112 --------- migrator.go | 3 + set_type.go | 228 ------------------ typed_list_type.go | 75 ------ types.go | 497 +++++++++++++++++++++++++++++++++++++++ 13 files changed, 586 insertions(+), 649 deletions(-) delete mode 100644 callbacks_registerer.go delete mode 100644 db_opener.go delete mode 100644 error_translator.go delete mode 100644 error_translator_test.go delete mode 100644 errors.go delete mode 100644 list_type.go delete mode 100644 map_type.go delete mode 100644 set_type.go delete mode 100644 typed_list_type.go create mode 100644 types.go diff --git a/callbacks_registerer.go b/callbacks_registerer.go deleted file mode 100644 index 5eab50f..0000000 --- a/callbacks_registerer.go +++ /dev/null @@ -1,12 +0,0 @@ -package dynmgrm - -import ( - "gorm.io/gorm" - "gorm.io/gorm/callbacks" -) - -type callbacksRegisterer struct{} - -func (c *callbacksRegisterer) Register(db *gorm.DB, config *callbacks.Config) { - callbacks.RegisterDefaultCallbacks(db, config) -} diff --git a/db_opener.go b/db_opener.go deleted file mode 100644 index bc449d0..0000000 --- a/db_opener.go +++ /dev/null @@ -1,22 +0,0 @@ -package dynmgrm - -import ( - "database/sql" -) - -type dbOpener struct { - dsn string - driverName string -} - -func (o dbOpener) DSN() string { - return o.dsn -} - -func (o dbOpener) DriverName() string { - return o.driverName -} - -func (o dbOpener) Apply() (*sql.DB, error) { - return sql.Open(o.DriverName(), o.DSN()) -} diff --git a/dynmgrm.go b/dynmgrm.go index 3dcb0ee..c289e56 100644 --- a/dynmgrm.go +++ b/dynmgrm.go @@ -5,7 +5,9 @@ package dynmgrm import ( "database/sql" + "errors" "fmt" + "github.com/btnguyen2k/godynamo" "gorm.io/gorm/migrator" "strconv" "strings" @@ -278,3 +280,39 @@ func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { }, } } + +// Translate it will translate the error to native gorm errors. +func (dialector Dialector) Translate(err error) error { + switch { + case errors.Is(err, godynamo.ErrTxCommitting), + errors.Is(err, godynamo.ErrTxRollingBack), + errors.Is(err, godynamo.ErrInTx), + errors.Is(err, godynamo.ErrInvalidTxStage), + errors.Is(err, godynamo.ErrNoTx): + return gorm.ErrInvalidTransaction + } + return err +} + +type dbOpener struct { + dsn string + driverName string +} + +func (o dbOpener) DSN() string { + return o.dsn +} + +func (o dbOpener) DriverName() string { + return o.driverName +} + +func (o dbOpener) Apply() (*sql.DB, error) { + return sql.Open(o.DriverName(), o.DSN()) +} + +type callbacksRegisterer struct{} + +func (c *callbacksRegisterer) Register(db *gorm.DB, config *callbacks.Config) { + callbacks.RegisterDefaultCallbacks(db, config) +} diff --git a/dynmgrm_test.go b/dynmgrm_test.go index 71d1874..adfecc6 100644 --- a/dynmgrm_test.go +++ b/dynmgrm_test.go @@ -3,6 +3,7 @@ package dynmgrm import ( "database/sql" "errors" + "github.com/btnguyen2k/godynamo" "github.com/miyamo2/dynmgrm/internal/mocks" "go.uber.org/mock/gomock" "reflect" @@ -484,3 +485,50 @@ func TestDialector_Initialize(t *testing.T) { }) } } + +func TestDialector_Translate(t *testing.T) { + type test struct { + args error + want error + } + errOther := errors.New("other") + tests := map[string]test{ + "happy_path/ErrTxCommitting": { + args: godynamo.ErrInTx, + want: gorm.ErrInvalidTransaction, + }, + "happy_path/ErrTxRollingBack": { + args: godynamo.ErrTxRollingBack, + want: gorm.ErrInvalidTransaction, + }, + "happy_path/ErrInTx": { + args: godynamo.ErrInTx, + want: gorm.ErrInvalidTransaction, + }, + "happy_path/ErrInvalidTxStage": { + args: godynamo.ErrInvalidTxStage, + want: gorm.ErrInvalidTransaction, + }, + "happy_path/ErrNoTx": { + args: godynamo.ErrNoTx, + want: gorm.ErrInvalidTransaction, + }, + "happy_path/other": { + args: errOther, + want: errOther, + }, + "happy_path/nil": { + args: nil, + want: nil, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + dialector := Dialector{} + err := dialector.Translate(tt.args) + if !errors.Is(err, tt.want) { + t.Errorf("Translate() error = %v, want %v", err, tt.want) + } + }) + } +} diff --git a/error_translator.go b/error_translator.go deleted file mode 100644 index afea188..0000000 --- a/error_translator.go +++ /dev/null @@ -1,22 +0,0 @@ -package dynmgrm - -import ( - "errors" - - "github.com/btnguyen2k/godynamo" - - "gorm.io/gorm" -) - -// Translate it will translate the error to native gorm errors. -func (dialector Dialector) Translate(err error) error { - switch { - case errors.Is(err, godynamo.ErrTxCommitting), - errors.Is(err, godynamo.ErrTxRollingBack), - errors.Is(err, godynamo.ErrInTx), - errors.Is(err, godynamo.ErrInvalidTxStage), - errors.Is(err, godynamo.ErrNoTx): - return gorm.ErrInvalidTransaction - } - return err -} diff --git a/error_translator_test.go b/error_translator_test.go deleted file mode 100644 index 09a8753..0000000 --- a/error_translator_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package dynmgrm - -import ( - "errors" - "testing" - - "github.com/btnguyen2k/godynamo" - "gorm.io/gorm" -) - -func TestDialector_Translate(t *testing.T) { - type test struct { - args error - want error - } - errOther := errors.New("other") - tests := map[string]test{ - "happy_path/ErrTxCommitting": { - args: godynamo.ErrInTx, - want: gorm.ErrInvalidTransaction, - }, - "happy_path/ErrTxRollingBack": { - args: godynamo.ErrTxRollingBack, - want: gorm.ErrInvalidTransaction, - }, - "happy_path/ErrInTx": { - args: godynamo.ErrInTx, - want: gorm.ErrInvalidTransaction, - }, - "happy_path/ErrInvalidTxStage": { - args: godynamo.ErrInvalidTxStage, - want: gorm.ErrInvalidTransaction, - }, - "happy_path/ErrNoTx": { - args: godynamo.ErrNoTx, - want: gorm.ErrInvalidTransaction, - }, - "happy_path/other": { - args: errOther, - want: errOther, - }, - "happy_path/nil": { - args: nil, - want: nil, - }, - } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - dialector := Dialector{} - err := dialector.Translate(tt.args) - if !errors.Is(err, tt.want) { - t.Errorf("Translate() error = %v, want %v", err, tt.want) - } - }) - } -} diff --git a/errors.go b/errors.go deleted file mode 100644 index bc36ead..0000000 --- a/errors.go +++ /dev/null @@ -1,9 +0,0 @@ -package dynmgrm - -import "errors" - -var ( - ErrCollectionAlreadyContainsItem = errors.New("collection already contains item") - ErrFailedToCast = errors.New("failed to cast") - ErrDynmgrmAreNotSupported = errors.New("dynmgrm are not supported this operation") -) diff --git a/list_type.go b/list_type.go deleted file mode 100644 index fa0d1b4..0000000 --- a/list_type.go +++ /dev/null @@ -1,113 +0,0 @@ -package dynmgrm - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -// compatibility check -var ( - _ gorm.Valuer = (*List)(nil) - _ sql.Scanner = (*List)(nil) -) - -// List is a DynamoDB list type. -// -// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html -type List []interface{} - -// GormDataType returns the data type for Gorm. -func (l *List) GormDataType() string { - return "dglist" -} - -// Scan implements the [sql.Scanner#Scan] -// -// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner -func (l *List) Scan(value interface{}) error { - if len(*l) != 0 { - return ErrCollectionAlreadyContainsItem - } - sv, ok := value.([]interface{}) - if !ok { - return errors.Join(ErrFailedToCast, fmt.Errorf("incompatible %T and %T", l, value)) - } - *l = sv - return resolveCollectionsNestedInList(l) -} - -// GormValue implements the [gorm.Valuer] interface. -// -// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer -func (l List) GormValue(_ context.Context, db *gorm.DB) clause.Expr { - if err := resolveCollectionsNestedInList(&l); err != nil { - _ = db.AddError(err) - return clause.Expr{} - } - av, err := toDocumentAttributeValue[*types.AttributeValueMemberL](l) - if err != nil { - _ = db.AddError(err) - return clause.Expr{} - } - return clause.Expr{SQL: "?", Vars: []interface{}{*av}} -} - -// resolveCollectionsNestedInList resolves nested collection type attribute. -func resolveCollectionsNestedInList(l *List) error { - for i, v := range *l { - if v, ok := v.(map[string]interface{}); ok { - m := Map{} - err := m.Scan(v) - if err != nil { - *l = nil - return err - } - (*l)[i] = m - continue - } - if isCompatibleWithSet[int](v) { - s := newSet[int]() - if err := s.Scan(v); err == nil { - (*l)[i] = s - continue - } - } - if isCompatibleWithSet[float64](v) { - s := newSet[float64]() - if err := s.Scan(v); err == nil { - (*l)[i] = s - continue - } - } - if isCompatibleWithSet[string](v) { - s := newSet[string]() - if err := s.Scan(v); err == nil { - (*l)[i] = s - continue - } - } - if isCompatibleWithSet[[]byte](v) { - s := newSet[[]byte]() - if err := s.Scan(v); err == nil { - (*l)[i] = s - continue - } - } - if v, ok := v.([]interface{}); ok { - il := List{} - err := il.Scan(v) - if err != nil { - *l = nil - return err - } - (*l)[i] = il - } - } - return nil -} diff --git a/map_type.go b/map_type.go deleted file mode 100644 index 8bcb00e..0000000 --- a/map_type.go +++ /dev/null @@ -1,112 +0,0 @@ -package dynmgrm - -import ( - "context" - "database/sql" - - "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -// compatibility check -var ( - _ gorm.Valuer = (*Map)(nil) - _ sql.Scanner = (*Map)(nil) -) - -// Map is a DynamoDB map type. -// -// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html -type Map map[string]interface{} - -// GormDataType returns the data type for Gorm. -func (m Map) GormDataType() string { - return "dgmap" -} - -// Scan implements the [sql.Scanner#Scan] -// -// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner -func (m *Map) Scan(value interface{}) error { - if len(*m) != 0 { - return ErrCollectionAlreadyContainsItem - } - mv, ok := value.(map[string]interface{}) - if !ok { - *m = nil - return ErrFailedToCast - } - *m = mv - return resolveCollectionsNestedInMap(m) -} - -// GormValue implements the [gorm.Valuer] interface. -// -// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer -func (m Map) GormValue(_ context.Context, db *gorm.DB) clause.Expr { - if err := resolveCollectionsNestedInMap(&m); err != nil { - _ = db.AddError(err) - return clause.Expr{} - } - av, err := toDocumentAttributeValue[*types.AttributeValueMemberM](m) - if err != nil { - _ = db.AddError(err) - return clause.Expr{} - } - return clause.Expr{SQL: "?", Vars: []interface{}{*av}} -} - -// resolveCollectionsNestedInMap resolves nested document type attribute. -func resolveCollectionsNestedInMap(m *Map) error { - for k, v := range *m { - if v, ok := v.(map[string]interface{}); ok { - im := Map{} - err := im.Scan(v) - if err != nil { - *m = nil - return err - } - (*m)[k] = im - continue - } - if isCompatibleWithSet[int](v) { - s := newSet[int]() - if err := s.Scan(v); err == nil { - (*m)[k] = s - continue - } - } - if isCompatibleWithSet[float64](v) { - s := newSet[float64]() - if err := s.Scan(v); err == nil { - (*m)[k] = s - continue - } - } - if isCompatibleWithSet[string](v) { - s := newSet[string]() - if err := s.Scan(v); err == nil { - (*m)[k] = s - continue - } - } - if isCompatibleWithSet[[]byte](v) { - s := newSet[[]byte]() - if err := s.Scan(v); err == nil { - (*m)[k] = s - continue - } - } - if v, ok := v.([]interface{}); ok { - l := List{} - err := l.Scan(v) - if err != nil { - *m = nil - return err - } - (*m)[k] = l - } - } - return nil -} diff --git a/migrator.go b/migrator.go index 8eb8ba8..2efad6d 100644 --- a/migrator.go +++ b/migrator.go @@ -3,6 +3,7 @@ package dynmgrm import ( + "errors" "fmt" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -12,6 +13,8 @@ import ( "strings" ) +var ErrDynmgrmAreNotSupported = errors.New("dynmgrm are not supported this operation") + // CapacityUnitsSpecifier could specify WCUs and RCU type CapacityUnitsSpecifier interface { WCU() int diff --git a/set_type.go b/set_type.go deleted file mode 100644 index 31614e9..0000000 --- a/set_type.go +++ /dev/null @@ -1,228 +0,0 @@ -package dynmgrm - -import ( - "context" - "database/sql" - "errors" - "math" - - "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -var ( - ErrValueIsIncompatibleOfStringSlice = errors.New("value is incompatible of string slice") - ErrValueIsIncompatibleOfIntSlice = errors.New("value is incompatible of int slice") - ErrValueIsIncompatibleOfFloat64Slice = errors.New("value is incompatible of float64 slice") - ErrValueIsIncompatibleOfBinarySlice = errors.New("value is incompatible of []byte slice") -) - -// SetSupportable are the types that support the Set -type SetSupportable interface { - string | []byte | int | float64 -} - -// compatibility check -var ( - _ gorm.Valuer = (*Set[string])(nil) - _ sql.Scanner = (*Set[string])(nil) -) - -// Set is a DynamoDB set type. -// -// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html -type Set[T SetSupportable] []T - -// GormDataType returns the data type for Gorm. -func (s *Set[T]) GormDataType() string { - return "dgsets" -} - -// Scan implements the [sql.Scanner#Scan] -// -// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner -func (s *Set[T]) Scan(value interface{}) error { - if len(*s) != 0 { - return ErrCollectionAlreadyContainsItem - } - if value == nil { - *s = nil - return nil - } - switch (interface{})(s).(type) { - case *Set[int]: - return scanAsIntSet((interface{})(s).(*Set[int]), value) - case *Set[float64]: - return scanAsFloat64Set((interface{})(s).(*Set[float64]), value) - case *Set[string]: - return scanAsStringSet((interface{})(s).(*Set[string]), value) - case *Set[[]byte]: - return scanAsBinarySet((interface{})(s).(*Set[[]byte]), value) - } - return nil -} - -// GormValue implements the [gorm.Valuer] interface. -// -// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer -func (s Set[T]) GormValue(_ context.Context, db *gorm.DB) clause.Expr { - switch s := (interface{})(s).(type) { - case Set[int]: - av, err := numericSetToAttributeValue(s) - if err != nil { - _ = db.AddError(err) - break - } - return clause.Expr{SQL: "?", Vars: []interface{}{*av}} - case Set[float64]: - av, err := numericSetToAttributeValue(s) - if err != nil { - _ = db.AddError(err) - break - } - return clause.Expr{SQL: "?", Vars: []interface{}{*av}} - case Set[string]: - av, err := stringSetToAttributeValue(s) - if err != nil { - _ = db.AddError(err) - break - } - return clause.Expr{SQL: "?", Vars: []interface{}{*av}} - case Set[[]byte]: - av, err := binarySetToAttributeValue(s) - if err != nil { - _ = db.AddError(err) - break - } - return clause.Expr{SQL: "?", Vars: []interface{}{*av}} - } - return clause.Expr{} -} - -func numericSetToAttributeValue[T Set[int] | Set[float64]](s T) (*types.AttributeValueMemberNS, error) { - return toDocumentAttributeValue[*types.AttributeValueMemberNS](s) -} - -func stringSetToAttributeValue(s Set[string]) (*types.AttributeValueMemberSS, error) { - return toDocumentAttributeValue[*types.AttributeValueMemberSS](s) -} - -func binarySetToAttributeValue(s Set[[]byte]) (*types.AttributeValueMemberBS, error) { - return toDocumentAttributeValue[*types.AttributeValueMemberBS](s) -} - -// scanAsIntSet scans the value as Set[int] -func scanAsIntSet(s *Set[int], value interface{}) error { - sv, ok := value.([]float64) - if !ok { - *s = nil - return ErrValueIsIncompatibleOfIntSlice - } - for _, v := range sv { - if math.Floor(v) != v { - *s = nil - return ErrValueIsIncompatibleOfIntSlice - } - *s = append(*s, int(v)) - } - return nil -} - -// scanAsFloat64Set scans the value as Set[float64] -func scanAsFloat64Set(s *Set[float64], value interface{}) error { - sv, ok := value.([]float64) - if !ok { - *s = nil - return ErrValueIsIncompatibleOfFloat64Slice - } - for _, v := range sv { - *s = append(*s, v) - } - return nil -} - -// scanAsStringSet scans the value as Set[string] -func scanAsStringSet(s *Set[string], value interface{}) error { - sv, ok := value.([]string) - if !ok { - *s = nil - return ErrValueIsIncompatibleOfStringSlice - } - for _, v := range sv { - *s = append(*s, v) - } - return nil -} - -// scanAsBinarySet scans the value as Set[[]byte] -func scanAsBinarySet(s *Set[[]byte], value interface{}) error { - sv, ok := value.([][]byte) - if !ok { - *s = nil - return ErrValueIsIncompatibleOfBinarySlice - } - for _, v := range sv { - *s = append(*s, v) - } - return nil -} - -func isCompatibleWithSet[T SetSupportable](value interface{}) (compatible bool) { - var t T - switch (interface{})(t).(type) { - case string: - compatible = isStringSetCompatible(value) - case int: - compatible = isIntSetCompatible(value) - case float64: - compatible = isFloat64SetCompatible(value) - case []byte: - compatible = isBinarySetCompatible(value) - } - return -} - -func isIntSetCompatible(value interface{}) (compatible bool) { - if _, ok := value.([]int); ok { - compatible = true - return - } - if value, ok := value.([]float64); ok { - compatible = true - for _, v := range value { - if math.Floor(v) == v { - compatible = true - continue - } - compatible = false - return - } - } - return -} - -func isStringSetCompatible(value interface{}) (compatible bool) { - if _, ok := value.([]string); ok { - compatible = true - } - return -} - -func isFloat64SetCompatible(value interface{}) (compatible bool) { - if _, ok := value.([]float64); ok { - compatible = true - } - return -} - -func isBinarySetCompatible(value interface{}) (compatible bool) { - if _, ok := value.([][]byte); ok { - compatible = true - } - return -} - -func newSet[T SetSupportable]() Set[T] { - return Set[T]{} -} diff --git a/typed_list_type.go b/typed_list_type.go deleted file mode 100644 index d7376f9..0000000 --- a/typed_list_type.go +++ /dev/null @@ -1,75 +0,0 @@ -package dynmgrm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "gorm.io/gorm" - "gorm.io/gorm/clause" - "reflect" - "slices" -) - -// compatibility check -var ( - _ gorm.Valuer = (*TypedList[interface{}])(nil) - _ sql.Scanner = (*TypedList[interface{}])(nil) -) - -// TypedList is a DynamoDB list type with type specification. -// -// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html -type TypedList[T any] []T - -// GormDataType returns the data type for Gorm. -func (l *TypedList[T]) GormDataType() string { - return "dgtypedlist" -} - -// Scan implements the [sql.Scanner#Scan] -// -// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner -func (l *TypedList[T]) Scan(value interface{}) error { - if len(*l) != 0 { - return ErrCollectionAlreadyContainsItem - } - sv, ok := value.([]interface{}) - if !ok { - return errors.Join(ErrFailedToCast, fmt.Errorf("incompatible %T and %T", l, value)) - } - *l = slices.Grow(*l, len(sv)) - for _, v := range sv { - mv, ok := v.(map[string]interface{}) - if !ok { - var t T - return errors.Join(ErrFailedToCast, fmt.Errorf("incompatible %T and %T", t, v)) - } - dest := new(T) - rv := reflect.ValueOf(dest) - rt := reflect.TypeOf(*dest) - err := assignMapValueToReflectValue(rt, rv, mv) - if err != nil { - return err - } - *l = append(*l, *dest) - } - return nil -} - -// GormValue implements the [gorm.Valuer] interface. -// -// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer -func (l TypedList[T]) GormValue(_ context.Context, db *gorm.DB) clause.Expr { - avl := types.AttributeValueMemberL{Value: make([]types.AttributeValue, 0, len(l))} - for _, v := range l { - av, err := toDocumentAttributeValue[*types.AttributeValueMemberM](v) - if err != nil { - _ = db.AddError(err) - return clause.Expr{} - } - avl.Value = append(avl.Value, av) - } - return clause.Expr{SQL: "?", Vars: []interface{}{avl}} -} diff --git a/types.go b/types.go new file mode 100644 index 0000000..65ad531 --- /dev/null +++ b/types.go @@ -0,0 +1,497 @@ +package dynmgrm + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "math" + "reflect" + "slices" +) + +var ( + ErrValueIsIncompatibleOfStringSlice = errors.New("value is incompatible of string slice") + ErrValueIsIncompatibleOfIntSlice = errors.New("value is incompatible of int slice") + ErrValueIsIncompatibleOfFloat64Slice = errors.New("value is incompatible of float64 slice") + ErrValueIsIncompatibleOfBinarySlice = errors.New("value is incompatible of []byte slice") + ErrCollectionAlreadyContainsItem = errors.New("collection already contains item") + ErrFailedToCast = errors.New("failed to cast") +) + +// SetSupportable are the types that support the Set +type SetSupportable interface { + string | []byte | int | float64 +} + +// compatibility check +var ( + _ gorm.Valuer = (*Set[string])(nil) + _ sql.Scanner = (*Set[string])(nil) +) + +// Set is a DynamoDB set type. +// +// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html +type Set[T SetSupportable] []T + +// GormDataType returns the data type for Gorm. +func (s *Set[T]) GormDataType() string { + return "dgsets" +} + +// Scan implements the [sql.Scanner#Scan] +// +// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner +func (s *Set[T]) Scan(value interface{}) error { + if len(*s) != 0 { + return ErrCollectionAlreadyContainsItem + } + if value == nil { + *s = nil + return nil + } + switch (interface{})(s).(type) { + case *Set[int]: + return scanAsIntSet((interface{})(s).(*Set[int]), value) + case *Set[float64]: + return scanAsFloat64Set((interface{})(s).(*Set[float64]), value) + case *Set[string]: + return scanAsStringSet((interface{})(s).(*Set[string]), value) + case *Set[[]byte]: + return scanAsBinarySet((interface{})(s).(*Set[[]byte]), value) + } + return nil +} + +// GormValue implements the [gorm.Valuer] interface. +// +// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer +func (s Set[T]) GormValue(_ context.Context, db *gorm.DB) clause.Expr { + switch s := (interface{})(s).(type) { + case Set[int]: + av, err := numericSetToAttributeValue(s) + if err != nil { + _ = db.AddError(err) + break + } + return clause.Expr{SQL: "?", Vars: []interface{}{*av}} + case Set[float64]: + av, err := numericSetToAttributeValue(s) + if err != nil { + _ = db.AddError(err) + break + } + return clause.Expr{SQL: "?", Vars: []interface{}{*av}} + case Set[string]: + av, err := stringSetToAttributeValue(s) + if err != nil { + _ = db.AddError(err) + break + } + return clause.Expr{SQL: "?", Vars: []interface{}{*av}} + case Set[[]byte]: + av, err := binarySetToAttributeValue(s) + if err != nil { + _ = db.AddError(err) + break + } + return clause.Expr{SQL: "?", Vars: []interface{}{*av}} + } + return clause.Expr{} +} + +func numericSetToAttributeValue[T Set[int] | Set[float64]](s T) (*types.AttributeValueMemberNS, error) { + return toDocumentAttributeValue[*types.AttributeValueMemberNS](s) +} + +func stringSetToAttributeValue(s Set[string]) (*types.AttributeValueMemberSS, error) { + return toDocumentAttributeValue[*types.AttributeValueMemberSS](s) +} + +func binarySetToAttributeValue(s Set[[]byte]) (*types.AttributeValueMemberBS, error) { + return toDocumentAttributeValue[*types.AttributeValueMemberBS](s) +} + +// scanAsIntSet scans the value as Set[int] +func scanAsIntSet(s *Set[int], value interface{}) error { + sv, ok := value.([]float64) + if !ok { + *s = nil + return ErrValueIsIncompatibleOfIntSlice + } + for _, v := range sv { + if math.Floor(v) != v { + *s = nil + return ErrValueIsIncompatibleOfIntSlice + } + *s = append(*s, int(v)) + } + return nil +} + +// scanAsFloat64Set scans the value as Set[float64] +func scanAsFloat64Set(s *Set[float64], value interface{}) error { + sv, ok := value.([]float64) + if !ok { + *s = nil + return ErrValueIsIncompatibleOfFloat64Slice + } + for _, v := range sv { + *s = append(*s, v) + } + return nil +} + +// scanAsStringSet scans the value as Set[string] +func scanAsStringSet(s *Set[string], value interface{}) error { + sv, ok := value.([]string) + if !ok { + *s = nil + return ErrValueIsIncompatibleOfStringSlice + } + for _, v := range sv { + *s = append(*s, v) + } + return nil +} + +// scanAsBinarySet scans the value as Set[[]byte] +func scanAsBinarySet(s *Set[[]byte], value interface{}) error { + sv, ok := value.([][]byte) + if !ok { + *s = nil + return ErrValueIsIncompatibleOfBinarySlice + } + for _, v := range sv { + *s = append(*s, v) + } + return nil +} + +func isCompatibleWithSet[T SetSupportable](value interface{}) (compatible bool) { + var t T + switch (interface{})(t).(type) { + case string: + compatible = isStringSetCompatible(value) + case int: + compatible = isIntSetCompatible(value) + case float64: + compatible = isFloat64SetCompatible(value) + case []byte: + compatible = isBinarySetCompatible(value) + } + return +} + +func isIntSetCompatible(value interface{}) (compatible bool) { + if _, ok := value.([]int); ok { + compatible = true + return + } + if value, ok := value.([]float64); ok { + compatible = true + for _, v := range value { + if math.Floor(v) == v { + compatible = true + continue + } + compatible = false + return + } + } + return +} + +func isStringSetCompatible(value interface{}) (compatible bool) { + if _, ok := value.([]string); ok { + compatible = true + } + return +} + +func isFloat64SetCompatible(value interface{}) (compatible bool) { + if _, ok := value.([]float64); ok { + compatible = true + } + return +} + +func isBinarySetCompatible(value interface{}) (compatible bool) { + if _, ok := value.([][]byte); ok { + compatible = true + } + return +} + +func newSet[T SetSupportable]() Set[T] { + return Set[T]{} +} + +// compatibility check +var ( + _ gorm.Valuer = (*List)(nil) + _ sql.Scanner = (*List)(nil) +) + +// List is a DynamoDB list type. +// +// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html +type List []interface{} + +// GormDataType returns the data type for Gorm. +func (l *List) GormDataType() string { + return "dglist" +} + +// Scan implements the [sql.Scanner#Scan] +// +// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner +func (l *List) Scan(value interface{}) error { + if len(*l) != 0 { + return ErrCollectionAlreadyContainsItem + } + sv, ok := value.([]interface{}) + if !ok { + return errors.Join(ErrFailedToCast, fmt.Errorf("incompatible %T and %T", l, value)) + } + *l = sv + return resolveCollectionsNestedInList(l) +} + +// GormValue implements the [gorm.Valuer] interface. +// +// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer +func (l List) GormValue(_ context.Context, db *gorm.DB) clause.Expr { + if err := resolveCollectionsNestedInList(&l); err != nil { + _ = db.AddError(err) + return clause.Expr{} + } + av, err := toDocumentAttributeValue[*types.AttributeValueMemberL](l) + if err != nil { + _ = db.AddError(err) + return clause.Expr{} + } + return clause.Expr{SQL: "?", Vars: []interface{}{*av}} +} + +// resolveCollectionsNestedInList resolves nested collection type attribute. +func resolveCollectionsNestedInList(l *List) error { + for i, v := range *l { + if v, ok := v.(map[string]interface{}); ok { + m := Map{} + err := m.Scan(v) + if err != nil { + *l = nil + return err + } + (*l)[i] = m + continue + } + if isCompatibleWithSet[int](v) { + s := newSet[int]() + if err := s.Scan(v); err == nil { + (*l)[i] = s + continue + } + } + if isCompatibleWithSet[float64](v) { + s := newSet[float64]() + if err := s.Scan(v); err == nil { + (*l)[i] = s + continue + } + } + if isCompatibleWithSet[string](v) { + s := newSet[string]() + if err := s.Scan(v); err == nil { + (*l)[i] = s + continue + } + } + if isCompatibleWithSet[[]byte](v) { + s := newSet[[]byte]() + if err := s.Scan(v); err == nil { + (*l)[i] = s + continue + } + } + if v, ok := v.([]interface{}); ok { + il := List{} + err := il.Scan(v) + if err != nil { + *l = nil + return err + } + (*l)[i] = il + } + } + return nil +} + +// compatibility check +var ( + _ gorm.Valuer = (*Map)(nil) + _ sql.Scanner = (*Map)(nil) +) + +// Map is a DynamoDB map type. +// +// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html +type Map map[string]interface{} + +// GormDataType returns the data type for Gorm. +func (m Map) GormDataType() string { + return "dgmap" +} + +// Scan implements the [sql.Scanner#Scan] +// +// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner +func (m *Map) Scan(value interface{}) error { + if len(*m) != 0 { + return ErrCollectionAlreadyContainsItem + } + mv, ok := value.(map[string]interface{}) + if !ok { + *m = nil + return ErrFailedToCast + } + *m = mv + return resolveCollectionsNestedInMap(m) +} + +// GormValue implements the [gorm.Valuer] interface. +// +// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer +func (m Map) GormValue(_ context.Context, db *gorm.DB) clause.Expr { + if err := resolveCollectionsNestedInMap(&m); err != nil { + _ = db.AddError(err) + return clause.Expr{} + } + av, err := toDocumentAttributeValue[*types.AttributeValueMemberM](m) + if err != nil { + _ = db.AddError(err) + return clause.Expr{} + } + return clause.Expr{SQL: "?", Vars: []interface{}{*av}} +} + +// resolveCollectionsNestedInMap resolves nested document type attribute. +func resolveCollectionsNestedInMap(m *Map) error { + for k, v := range *m { + if v, ok := v.(map[string]interface{}); ok { + im := Map{} + err := im.Scan(v) + if err != nil { + *m = nil + return err + } + (*m)[k] = im + continue + } + if isCompatibleWithSet[int](v) { + s := newSet[int]() + if err := s.Scan(v); err == nil { + (*m)[k] = s + continue + } + } + if isCompatibleWithSet[float64](v) { + s := newSet[float64]() + if err := s.Scan(v); err == nil { + (*m)[k] = s + continue + } + } + if isCompatibleWithSet[string](v) { + s := newSet[string]() + if err := s.Scan(v); err == nil { + (*m)[k] = s + continue + } + } + if isCompatibleWithSet[[]byte](v) { + s := newSet[[]byte]() + if err := s.Scan(v); err == nil { + (*m)[k] = s + continue + } + } + if v, ok := v.([]interface{}); ok { + l := List{} + err := l.Scan(v) + if err != nil { + *m = nil + return err + } + (*m)[k] = l + } + } + return nil +} + +// compatibility check +var ( + _ gorm.Valuer = (*TypedList[interface{}])(nil) + _ sql.Scanner = (*TypedList[interface{}])(nil) +) + +// TypedList is a DynamoDB list type with type specification. +// +// See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.NamingRulesDataTypes.html +type TypedList[T any] []T + +// GormDataType returns the data type for Gorm. +func (l *TypedList[T]) GormDataType() string { + return "dgtypedlist" +} + +// Scan implements the [sql.Scanner#Scan] +// +// [sql.Scanner#Scan]: https://golang.org/pkg/database/sql/#Scanner +func (l *TypedList[T]) Scan(value interface{}) error { + if len(*l) != 0 { + return ErrCollectionAlreadyContainsItem + } + sv, ok := value.([]interface{}) + if !ok { + return errors.Join(ErrFailedToCast, fmt.Errorf("incompatible %T and %T", l, value)) + } + *l = slices.Grow(*l, len(sv)) + for _, v := range sv { + mv, ok := v.(map[string]interface{}) + if !ok { + var t T + return errors.Join(ErrFailedToCast, fmt.Errorf("incompatible %T and %T", t, v)) + } + dest := new(T) + rv := reflect.ValueOf(dest) + rt := reflect.TypeOf(*dest) + err := assignMapValueToReflectValue(rt, rv, mv) + if err != nil { + return err + } + *l = append(*l, *dest) + } + return nil +} + +// GormValue implements the [gorm.Valuer] interface. +// +// [gorm.Valuer]: https://pkg.go.dev/gorm.io/gorm#Valuer +func (l TypedList[T]) GormValue(_ context.Context, db *gorm.DB) clause.Expr { + avl := types.AttributeValueMemberL{Value: make([]types.AttributeValue, 0, len(l))} + for _, v := range l { + av, err := toDocumentAttributeValue[*types.AttributeValueMemberM](v) + if err != nil { + _ = db.AddError(err) + return clause.Expr{} + } + avl.Value = append(avl.Value, av) + } + return clause.Expr{SQL: "?", Vars: []interface{}{avl}} +}